// tenflowers_core/fallback.rs
1//! Automatic fallback mechanisms for operation recovery
2//!
3//! This module provides utilities for automatic fallback when operations fail,
4//! particularly for GPU-to-CPU fallback scenarios.
5
6#[cfg(feature = "gpu")]
7use crate::Device;
8use crate::{Result, Tensor, TensorError};
9use scirs2_core::num_traits;
10use std::sync::atomic::{AtomicBool, Ordering};
11
/// Global flag to enable/disable automatic fallback.
///
/// Defaults to `true`. Toggled by [`set_auto_fallback_enabled`] and read by
/// [`is_auto_fallback_enabled`]; both use `SeqCst` ordering.
static AUTO_FALLBACK_ENABLED: AtomicBool = AtomicBool::new(true);
14
/// Configuration for fallback behavior.
///
/// A process-wide copy is held in a global (see [`get_fallback_config`] /
/// [`set_fallback_config`]); the helpers in this module re-read it on each
/// operation.
#[derive(Debug, Clone)]
pub struct FallbackConfig {
    /// Enable automatic GPU-to-CPU fallback
    pub gpu_to_cpu: bool,
    /// Enable automatic precision reduction
    /// (not consulted by any helper visible in this file — presumably read
    /// elsewhere; TODO confirm)
    pub reduce_precision: bool,
    /// Enable memory cleanup and retry
    /// (not consulted by any helper visible in this file; callers pass
    /// `max_retries` to `cleanup_memory_and_retry` directly)
    pub memory_cleanup: bool,
    /// Maximum number of retry attempts
    pub max_retries: usize,
    /// Log fallback attempts
    pub log_fallbacks: bool,
}
29
30impl Default for FallbackConfig {
31    fn default() -> Self {
32        Self {
33            gpu_to_cpu: true,
34            reduce_precision: false,
35            memory_cleanup: true,
36            max_retries: 3,
37            log_fallbacks: true,
38        }
39    }
40}
41
42/// Global fallback configuration
43#[allow(static_mut_refs)]
44static mut GLOBAL_FALLBACK_CONFIG: Option<FallbackConfig> = None;
45static FALLBACK_CONFIG_INIT: std::sync::Once = std::sync::Once::new();
46
47/// Get the global fallback configuration
48#[allow(static_mut_refs)]
49pub fn get_fallback_config() -> FallbackConfig {
50    unsafe {
51        FALLBACK_CONFIG_INIT.call_once(|| {
52            GLOBAL_FALLBACK_CONFIG = Some(FallbackConfig::default());
53        });
54        GLOBAL_FALLBACK_CONFIG
55            .as_ref()
56            .expect("Fallback config should be initialized")
57            .clone()
58    }
59}
60
61/// Set the global fallback configuration
62#[allow(static_mut_refs)]
63pub fn set_fallback_config(config: FallbackConfig) {
64    unsafe {
65        GLOBAL_FALLBACK_CONFIG = Some(config);
66    }
67}
68
/// Enable or disable automatic fallback globally.
///
/// Takes effect for all subsequent operations in the process (`SeqCst` store).
pub fn set_auto_fallback_enabled(enabled: bool) {
    AUTO_FALLBACK_ENABLED.store(enabled, Ordering::SeqCst);
}
73
/// Check if automatic fallback is enabled.
///
/// Reads the global flag set by [`set_auto_fallback_enabled`] (`SeqCst` load);
/// defaults to `true`.
pub fn is_auto_fallback_enabled() -> bool {
    AUTO_FALLBACK_ENABLED.load(Ordering::SeqCst)
}
78
/// Trait for operations that support fallback.
///
/// Implementors describe how to run an operation with automatic recovery and
/// how to force the CPU execution path. No implementations are visible in
/// this file.
pub trait FallbackOperation<T> {
    /// Execute the operation with automatic fallback
    fn with_fallback(self) -> Result<T>;

    /// Execute the operation on CPU as fallback
    fn fallback_to_cpu(self) -> Result<T>;
}
87
88/// Execute a binary operation with automatic fallback
89pub fn execute_binary_op_with_fallback<T, F>(
90    operation_name: &str,
91    tensor_a: &Tensor<T>,
92    tensor_b: &Tensor<T>,
93    gpu_op: F,
94    #[allow(unused_variables)] cpu_op: F,
95) -> Result<Tensor<T>>
96where
97    T: Clone
98        + Default
99        + scirs2_core::num_traits::Zero
100        + scirs2_core::num_traits::One
101        + Send
102        + Sync
103        + 'static
104        + bytemuck::Pod,
105    F: Fn(&Tensor<T>, &Tensor<T>) -> Result<Tensor<T>>,
106{
107    let config = get_fallback_config();
108
109    if !is_auto_fallback_enabled() {
110        return gpu_op(tensor_a, tensor_b);
111    }
112
113    // Try the primary operation first
114    match gpu_op(tensor_a, tensor_b) {
115        Ok(result) => Ok(result),
116        Err(error) => {
117            if config.log_fallbacks {
118                eprintln!("Operation '{operation_name}' failed: {error}. Attempting fallback...");
119            }
120
121            // Check if this error supports fallback
122            if error.supports_fallback() && config.gpu_to_cpu {
123                // Try to move tensors to CPU and retry
124                match (tensor_a.device(), tensor_b.device()) {
125                    #[cfg(feature = "gpu")]
126                    (Device::Gpu(_), _) | (_, Device::Gpu(_)) => {
127                        if config.log_fallbacks {
128                            eprintln!(
129                                "Falling back to CPU execution for operation '{}'",
130                                operation_name
131                            );
132                        }
133
134                        // Move tensors to CPU
135                        let cpu_a = tensor_a.to_device(Device::Cpu)?;
136                        let cpu_b = tensor_b.to_device(Device::Cpu)?;
137
138                        // Execute on CPU
139                        match cpu_op(&cpu_a, &cpu_b) {
140                            Ok(result) => {
141                                if config.log_fallbacks {
142                                    eprintln!(
143                                        "CPU fallback successful for operation '{}'",
144                                        operation_name
145                                    );
146                                }
147                                Ok(result)
148                            }
149                            Err(cpu_error) => {
150                                if config.log_fallbacks {
151                                    eprintln!(
152                                        "CPU fallback also failed for operation '{}': {}",
153                                        operation_name, cpu_error
154                                    );
155                                }
156                                Err(cpu_error)
157                            }
158                        }
159                    }
160                    _ => {
161                        // Already on CPU or other device, can't fallback further
162                        Err(error)
163                    }
164                }
165            } else {
166                Err(error)
167            }
168        }
169    }
170}
171
172/// Execute a unary operation with automatic fallback
173pub fn execute_unary_op_with_fallback<T, F>(
174    operation_name: &str,
175    tensor: &Tensor<T>,
176    gpu_op: F,
177    #[allow(unused_variables)] cpu_op: F,
178) -> Result<Tensor<T>>
179where
180    T: Clone
181        + Default
182        + scirs2_core::num_traits::Zero
183        + scirs2_core::num_traits::One
184        + Send
185        + Sync
186        + 'static
187        + bytemuck::Pod,
188    F: Fn(&Tensor<T>) -> Result<Tensor<T>>,
189{
190    let config = get_fallback_config();
191
192    if !is_auto_fallback_enabled() {
193        return gpu_op(tensor);
194    }
195
196    // Try the primary operation first
197    match gpu_op(tensor) {
198        Ok(result) => Ok(result),
199        Err(error) => {
200            if config.log_fallbacks {
201                eprintln!("Operation '{operation_name}' failed: {error}. Attempting fallback...");
202            }
203
204            // Check if this error supports fallback
205            if error.supports_fallback() && config.gpu_to_cpu {
206                // Try to move tensor to CPU and retry
207                #[cfg(feature = "gpu")]
208                return if let Device::Gpu(_) = tensor.device() {
209                    if config.log_fallbacks {
210                        eprintln!(
211                            "Falling back to CPU execution for operation '{}'",
212                            operation_name
213                        );
214                    }
215
216                    // Move tensor to CPU
217                    let cpu_tensor = tensor.to_device(Device::Cpu)?;
218
219                    // Execute on CPU
220                    match cpu_op(&cpu_tensor) {
221                        Ok(result) => {
222                            if config.log_fallbacks {
223                                eprintln!(
224                                    "CPU fallback successful for operation '{}'",
225                                    operation_name
226                                );
227                            }
228                            Ok(result)
229                        }
230                        Err(cpu_error) => {
231                            if config.log_fallbacks {
232                                eprintln!(
233                                    "CPU fallback also failed for operation '{}': {}",
234                                    operation_name, cpu_error
235                                );
236                            }
237                            Err(cpu_error)
238                        }
239                    }
240                } else {
241                    // Already on CPU or other device, can't fallback further
242                    Err(error)
243                };
244
245                #[cfg(not(feature = "gpu"))]
246                return Err(error);
247            } else {
248                Err(error)
249            }
250        }
251    }
252}
253
254/// Memory cleanup utility for fallback scenarios
255pub fn cleanup_memory_and_retry<T, F>(operation: F, max_retries: usize) -> Result<T>
256where
257    F: Fn() -> Result<T>,
258{
259    let mut attempt = 0;
260
261    loop {
262        match operation() {
263            Ok(result) => return Ok(result),
264            Err(error) => {
265                attempt += 1;
266
267                if attempt >= max_retries {
268                    return Err(error);
269                }
270
271                // Check if this is a memory-related error
272                match &error {
273                    TensorError::AllocationError { .. } | TensorError::ResourceExhausted { .. } => {
274                        eprintln!("Memory error detected, attempting cleanup (attempt {attempt}/{max_retries})");
275
276                        // Trigger garbage collection if available
277                        #[cfg(feature = "gpu")]
278                        {
279                            // Clear GPU memory pools
280                            crate::memory::global_monitor().clear();
281                        }
282
283                        // Force garbage collection
284                        std::hint::black_box(Vec::<u8>::new());
285
286                        // Short delay before retry
287                        std::thread::sleep(std::time::Duration::from_millis(100));
288                    }
289                    _ => {
290                        // Not a memory error, don't retry
291                        return Err(error);
292                    }
293                }
294            }
295        }
296    }
297}
298
/// Wrapper for automatic fallback of results.
///
/// Pairs an operation's `Result` with its name so a CPU fallback closure can
/// be attached (and logged) after the fact via
/// [`FallbackWrapper::with_cpu_fallback`].
pub struct FallbackWrapper<T> {
    /// Outcome of the primary attempt.
    result: Result<T>,
    /// Operation name, used only in fallback log messages.
    operation_name: String,
}
304
305impl<T> FallbackWrapper<T> {
306    pub fn new(result: Result<T>, operation_name: &str) -> Self {
307        Self {
308            result,
309            operation_name: operation_name.to_string(),
310        }
311    }
312
313    pub fn with_cpu_fallback<F>(self, cpu_fallback: F) -> Result<T>
314    where
315        F: FnOnce() -> Result<T>,
316    {
317        match self.result {
318            Ok(result) => Ok(result),
319            Err(error) => {
320                if error.supports_fallback() && is_auto_fallback_enabled() {
321                    let config = get_fallback_config();
322                    if config.log_fallbacks {
323                        eprintln!(
324                            "Attempting CPU fallback for operation '{}'",
325                            self.operation_name
326                        );
327                    }
328                    cpu_fallback()
329                } else {
330                    Err(error)
331                }
332            }
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    // `super::*` already provides everything these tests use (FallbackConfig,
    // the flag helpers, FallbackWrapper, Result, TensorError). The previous
    // `use crate::{DType, Device, Tensor};` was entirely unused and only
    // produced unused_imports warnings, so it is removed.
    use super::*;

    #[test]
    fn test_fallback_config() {
        let config = FallbackConfig::default();
        assert!(config.gpu_to_cpu);
        assert!(config.memory_cleanup);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_auto_fallback_flag() {
        // NOTE: mutates process-global state; tests reading the flag
        // concurrently could observe the intermediate `false` value.
        assert!(is_auto_fallback_enabled()); // Default is true

        set_auto_fallback_enabled(false);
        assert!(!is_auto_fallback_enabled());

        set_auto_fallback_enabled(true);
        assert!(is_auto_fallback_enabled());
    }

    #[test]
    fn test_fallback_wrapper() {
        // A successful result must pass through without invoking the fallback.
        let success_result: Result<i32> = Ok(42);
        let wrapper = FallbackWrapper::new(success_result, "test_op");

        let result = wrapper.with_cpu_fallback(|| Ok(100));
        assert_eq!(result.expect("test: operation should succeed"), 42);
    }

    #[test]
    fn test_error_supports_fallback() {
        let gpu_error = TensorError::unsupported_device("test", "gpu:0", true);
        assert!(gpu_error.supports_fallback());

        let shape_error = TensorError::shape_mismatch("test", "[2, 2]", "[3, 3]");
        assert!(!shape_error.supports_fallback());
    }
}