// tenflowers_core/fallback.rs
#[cfg(feature = "gpu")]
7use crate::Device;
8use crate::{Result, Tensor, TensorError};
9use scirs2_core::num_traits;
10use std::sync::atomic::{AtomicBool, Ordering};
11
/// Process-wide toggle for automatic fallback; enabled by default.
/// Read/written with `SeqCst` by the getter/setter below.
static AUTO_FALLBACK_ENABLED: AtomicBool = AtomicBool::new(true);
14
/// Configuration for the automatic fallback machinery in this module.
#[derive(Debug, Clone)]
pub struct FallbackConfig {
    /// Allow falling back from GPU execution to CPU execution on failure.
    pub gpu_to_cpu: bool,
    /// Allow retrying at reduced precision. NOTE(review): nothing in this
    /// module reads this flag — confirm it is consumed elsewhere.
    pub reduce_precision: bool,
    /// Attempt memory cleanup before retrying. NOTE(review): not consulted
    /// in this module (`cleanup_memory_and_retry` takes its own limit).
    pub memory_cleanup: bool,
    /// Maximum number of attempts for cleanup-and-retry loops.
    pub max_retries: usize,
    /// Emit fallback diagnostics to stderr via `eprintln!`.
    pub log_fallbacks: bool,
}
29
30impl Default for FallbackConfig {
31 fn default() -> Self {
32 Self {
33 gpu_to_cpu: true,
34 reduce_precision: false,
35 memory_cleanup: true,
36 max_retries: 3,
37 log_fallbacks: true,
38 }
39 }
40}
41
42#[allow(static_mut_refs)]
44static mut GLOBAL_FALLBACK_CONFIG: Option<FallbackConfig> = None;
45static FALLBACK_CONFIG_INIT: std::sync::Once = std::sync::Once::new();
46
47#[allow(static_mut_refs)]
49pub fn get_fallback_config() -> FallbackConfig {
50 unsafe {
51 FALLBACK_CONFIG_INIT.call_once(|| {
52 GLOBAL_FALLBACK_CONFIG = Some(FallbackConfig::default());
53 });
54 GLOBAL_FALLBACK_CONFIG
55 .as_ref()
56 .expect("Fallback config should be initialized")
57 .clone()
58 }
59}
60
61#[allow(static_mut_refs)]
63pub fn set_fallback_config(config: FallbackConfig) {
64 unsafe {
65 GLOBAL_FALLBACK_CONFIG = Some(config);
66 }
67}
68
/// Enables or disables automatic fallback globally (atomic `SeqCst` store).
pub fn set_auto_fallback_enabled(enabled: bool) {
    AUTO_FALLBACK_ENABLED.store(enabled, Ordering::SeqCst);
}
73
/// Reports whether automatic fallback is currently enabled (atomic `SeqCst` load).
pub fn is_auto_fallback_enabled() -> bool {
    AUTO_FALLBACK_ENABLED.load(Ordering::SeqCst)
}
78
/// Trait for operations that can retry themselves under the fallback
/// policy. NOTE(review): no implementors are visible in this file —
/// confirm where (or whether) it is implemented.
pub trait FallbackOperation<T> {
    /// Execute, applying the globally configured fallback behavior on failure.
    fn with_fallback(self) -> Result<T>;

    /// Force the CPU execution path.
    fn fallback_to_cpu(self) -> Result<T>;
}
87
/// Runs a binary tensor operation with automatic GPU→CPU fallback.
///
/// `gpu_op` is attempted first. If it fails with an error for which
/// `supports_fallback()` is true, the global config enables
/// `gpu_to_cpu`, and (on `gpu` builds) at least one operand is on a
/// GPU device, both operands are copied to the CPU and `cpu_op` is
/// retried there. In every other failure case the original error is
/// returned.
///
/// `cpu_op` is only reachable when the `gpu` feature is compiled in,
/// hence the `#[allow(unused_variables)]` on the parameter.
pub fn execute_binary_op_with_fallback<T, F>(
    operation_name: &str,
    tensor_a: &Tensor<T>,
    tensor_b: &Tensor<T>,
    gpu_op: F,
    #[allow(unused_variables)] cpu_op: F,
) -> Result<Tensor<T>>
where
    // NOTE(review): these bounds presumably mirror what Tensor ops
    // require elsewhere in the crate — confirm they are all needed.
    T: Clone
        + Default
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + 'static
        + bytemuck::Pod,
    F: Fn(&Tensor<T>, &Tensor<T>) -> Result<Tensor<T>>,
{
    let config = get_fallback_config();

    // Fallback globally disabled: run the primary op and surface its result.
    if !is_auto_fallback_enabled() {
        return gpu_op(tensor_a, tensor_b);
    }

    match gpu_op(tensor_a, tensor_b) {
        Ok(result) => Ok(result),
        Err(error) => {
            if config.log_fallbacks {
                eprintln!("Operation '{operation_name}' failed: {error}. Attempting fallback...");
            }

            // Only errors flagged as fallback-capable qualify, and only
            // when GPU→CPU fallback is enabled in the config.
            if error.supports_fallback() && config.gpu_to_cpu {
                match (tensor_a.device(), tensor_b.device()) {
                    // At least one operand is on a GPU: move both to the
                    // CPU and retry with the CPU implementation.
                    #[cfg(feature = "gpu")]
                    (Device::Gpu(_), _) | (_, Device::Gpu(_)) => {
                        if config.log_fallbacks {
                            eprintln!(
                                "Falling back to CPU execution for operation '{}'",
                                operation_name
                            );
                        }

                        let cpu_a = tensor_a.to_device(Device::Cpu)?;
                        let cpu_b = tensor_b.to_device(Device::Cpu)?;

                        match cpu_op(&cpu_a, &cpu_b) {
                            Ok(result) => {
                                if config.log_fallbacks {
                                    eprintln!(
                                        "CPU fallback successful for operation '{}'",
                                        operation_name
                                    );
                                }
                                Ok(result)
                            }
                            Err(cpu_error) => {
                                if config.log_fallbacks {
                                    eprintln!(
                                        "CPU fallback also failed for operation '{}': {}",
                                        operation_name, cpu_error
                                    );
                                }
                                // Report the CPU error; the original GPU
                                // error was already logged above.
                                Err(cpu_error)
                            }
                        }
                    }
                    // Operands already on CPU (or `gpu` feature off):
                    // there is nothing to fall back to.
                    _ => {
                        Err(error)
                    }
                }
            } else {
                Err(error)
            }
        }
    }
}
171
/// Runs a unary tensor operation with automatic GPU→CPU fallback.
///
/// Mirrors `execute_binary_op_with_fallback` for a single operand.
/// The two `cfg`-guarded `return` statements below select exactly one
/// code path per build: on `gpu` builds the `if let` expression is
/// returned; on non-`gpu` builds only `return Err(error)` compiles.
///
/// `cpu_op` is only reachable when the `gpu` feature is compiled in,
/// hence the `#[allow(unused_variables)]` on the parameter.
pub fn execute_unary_op_with_fallback<T, F>(
    operation_name: &str,
    tensor: &Tensor<T>,
    gpu_op: F,
    #[allow(unused_variables)] cpu_op: F,
) -> Result<Tensor<T>>
where
    T: Clone
        + Default
        + scirs2_core::num_traits::Zero
        + scirs2_core::num_traits::One
        + Send
        + Sync
        + 'static
        + bytemuck::Pod,
    F: Fn(&Tensor<T>) -> Result<Tensor<T>>,
{
    let config = get_fallback_config();

    // Fallback globally disabled: run the primary op and surface its result.
    if !is_auto_fallback_enabled() {
        return gpu_op(tensor);
    }

    match gpu_op(tensor) {
        Ok(result) => Ok(result),
        Err(error) => {
            if config.log_fallbacks {
                eprintln!("Operation '{operation_name}' failed: {error}. Attempting fallback...");
            }

            if error.supports_fallback() && config.gpu_to_cpu {
                // GPU builds: retry on the CPU when the tensor lives on a GPU.
                #[cfg(feature = "gpu")]
                return if let Device::Gpu(_) = tensor.device() {
                    if config.log_fallbacks {
                        eprintln!(
                            "Falling back to CPU execution for operation '{}'",
                            operation_name
                        );
                    }

                    let cpu_tensor = tensor.to_device(Device::Cpu)?;

                    match cpu_op(&cpu_tensor) {
                        Ok(result) => {
                            if config.log_fallbacks {
                                eprintln!(
                                    "CPU fallback successful for operation '{}'",
                                    operation_name
                                );
                            }
                            Ok(result)
                        }
                        Err(cpu_error) => {
                            if config.log_fallbacks {
                                eprintln!(
                                    "CPU fallback also failed for operation '{}': {}",
                                    operation_name, cpu_error
                                );
                            }
                            Err(cpu_error)
                        }
                    }
                } else {
                    // Tensor is not on a GPU: nothing to fall back to.
                    Err(error)
                };

                // Non-GPU builds: no fallback target exists.
                #[cfg(not(feature = "gpu"))]
                return Err(error);
            } else {
                Err(error)
            }
        }
    }
}
253
254pub fn cleanup_memory_and_retry<T, F>(operation: F, max_retries: usize) -> Result<T>
256where
257 F: Fn() -> Result<T>,
258{
259 let mut attempt = 0;
260
261 loop {
262 match operation() {
263 Ok(result) => return Ok(result),
264 Err(error) => {
265 attempt += 1;
266
267 if attempt >= max_retries {
268 return Err(error);
269 }
270
271 match &error {
273 TensorError::AllocationError { .. } | TensorError::ResourceExhausted { .. } => {
274 eprintln!("Memory error detected, attempting cleanup (attempt {attempt}/{max_retries})");
275
276 #[cfg(feature = "gpu")]
278 {
279 crate::memory::global_monitor().clear();
281 }
282
283 std::hint::black_box(Vec::<u8>::new());
285
286 std::thread::sleep(std::time::Duration::from_millis(100));
288 }
289 _ => {
290 return Err(error);
292 }
293 }
294 }
295 }
296 }
297}
298
/// Wraps a primary attempt's result so a CPU fallback closure can be
/// attached fluently via `with_cpu_fallback`.
pub struct FallbackWrapper<T> {
    /// Outcome of the primary (e.g. GPU) attempt.
    result: Result<T>,
    /// Operation name used in fallback log messages.
    operation_name: String,
}
304
305impl<T> FallbackWrapper<T> {
306 pub fn new(result: Result<T>, operation_name: &str) -> Self {
307 Self {
308 result,
309 operation_name: operation_name.to_string(),
310 }
311 }
312
313 pub fn with_cpu_fallback<F>(self, cpu_fallback: F) -> Result<T>
314 where
315 F: FnOnce() -> Result<T>,
316 {
317 match self.result {
318 Ok(result) => Ok(result),
319 Err(error) => {
320 if error.supports_fallback() && is_auto_fallback_enabled() {
321 let config = get_fallback_config();
322 if config.log_fallbacks {
323 eprintln!(
324 "Attempting CPU fallback for operation '{}'",
325 self.operation_name
326 );
327 }
328 cpu_fallback()
329 } else {
330 Err(error)
331 }
332 }
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    // `super::*` already brings in everything these tests touch
    // (FallbackConfig, FallbackWrapper, Result, TensorError, the
    // enable/disable helpers). The previous
    // `use crate::{DType, Device, Tensor};` referenced nothing in this
    // module and only produced unused-import warnings, so it is removed.
    use super::*;

    #[test]
    fn test_fallback_config() {
        let config = FallbackConfig::default();
        assert!(config.gpu_to_cpu);
        assert!(config.memory_cleanup);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_auto_fallback_flag() {
        // NOTE: relies on the process-wide flag still being in its
        // default (enabled) state; no other test in this module mutates
        // it concurrently.
        assert!(is_auto_fallback_enabled());
        set_auto_fallback_enabled(false);
        assert!(!is_auto_fallback_enabled());

        // Restore the default so later code sees the expected state.
        set_auto_fallback_enabled(true);
        assert!(is_auto_fallback_enabled());
    }

    #[test]
    fn test_fallback_wrapper() {
        // A successful primary result must pass through untouched; the
        // fallback closure must not run.
        let success_result: Result<i32> = Ok(42);
        let wrapper = FallbackWrapper::new(success_result, "test_op");

        let result = wrapper.with_cpu_fallback(|| Ok(100));
        assert_eq!(result.expect("test: operation should succeed"), 42);
    }

    #[test]
    fn test_error_supports_fallback() {
        // Device errors are fallback-capable; shape mismatches are not.
        let gpu_error = TensorError::unsupported_device("test", "gpu:0", true);
        assert!(gpu_error.supports_fallback());

        let shape_error = TensorError::shape_mismatch("test", "[2, 2]", "[3, 3]");
        assert!(!shape_error.supports_fallback());
    }
}
378}