// torsh_backend/backend.rs — core backend trait and implementations
1//! Core backend trait and implementations
2
3use crate::memory::MemoryManager;
4use crate::profiler::Profiler;
5use crate::{Buffer, BufferDescriptor, Device, Kernel, KernelDescriptor};
6use torsh_core::{device::DeviceType, dtype::DType, error::TorshError};
7
8#[cfg(not(feature = "std"))]
9use alloc::{boxed::Box, string::String, vec::Vec};
10
/// Result type for backend operations.
///
/// All fallible backend APIs return `Result<T, TorshError>` so errors unify
/// with the rest of the torsh error hierarchy (see the note at line 14: the
/// former `BackendError` type was removed in favor of `TorshError`).
pub type BackendResult<T> = Result<T, TorshError>;
13
14// BackendError removed - use TorshError directly
15
/// Core backend information and capabilities.
///
/// Read-only identity/capability queries; no method here performs work or
/// mutates state.
///
/// NOTE(review): this trait requires `std::fmt::Debug` while the file also
/// imports from `alloc` under `not(feature = "std")` — the `std::` paths here
/// will not compile for `no_std`; confirm whether `no_std` support is still
/// intended.
pub trait BackendCore: Send + Sync + std::fmt::Debug {
    /// Get the device type this backend supports
    fn device_type(&self) -> DeviceType;

    /// Get the human-readable name of this backend
    fn name(&self) -> &str;

    /// Check if this backend is available on the current system.
    ///
    /// Presumably `Ok(false)` means "probed fine, not available" while `Err`
    /// means the probe itself failed — confirm against implementations.
    fn is_available(&self) -> BackendResult<bool>;

    /// Get backend-specific capabilities (limits, dtypes, hardware features)
    fn capabilities(&self) -> BackendCapabilities;

    /// Get performance hints callers can use to tune workloads
    fn performance_hints(&self) -> PerformanceHints;
}
33
/// Backend lifecycle management.
///
/// Drives the initialize/shutdown state machine. Transitions take `&mut self`,
/// so callers need exclusive access to change lifecycle state.
#[async_trait::async_trait]
pub trait BackendLifecycle: Send + Sync {
    /// Initialize the backend, acquiring whatever resources it needs
    async fn initialize(&mut self) -> BackendResult<()>;

    /// Shutdown the backend and cleanup resources
    async fn shutdown(&mut self) -> BackendResult<()>;

    /// Check if the backend is currently initialized
    fn is_initialized(&self) -> bool;
}
46
/// Device management operations.
///
/// Enumeration and construction of [`Device`] handles owned by a backend.
pub trait BackendDeviceManager: Send + Sync {
    /// Get all devices this backend can currently use
    fn devices(&self) -> BackendResult<Vec<Device>>;

    /// Get the default device for this backend
    fn default_device(&self) -> BackendResult<Device>;

    /// Create a device handle from a backend-local device ID
    fn create_device(&self, device_id: usize) -> BackendResult<Device>;

    /// Get the number of devices this backend exposes
    fn device_count(&self) -> BackendResult<usize>;

    /// Check if the device with the given ID is available (infallible probe)
    fn is_device_available(&self, device_id: usize) -> bool;
}
64
/// Resource creation and management with enhanced lifetime safety.
///
/// Object-safe resource factory: buffers, kernels, and per-device services
/// (memory manager, profiler) are returned as owned values / boxed trait
/// objects so the trait can be used behind `dyn`.
pub trait BackendResourceManager: Send + Sync {
    /// Create a buffer on the specified device
    fn create_buffer(
        &self,
        device: &Device,
        descriptor: &BufferDescriptor,
    ) -> BackendResult<Buffer>;

    /// Create a kernel from source or bytecode as described by `descriptor`
    fn create_kernel(
        &self,
        device: &Device,
        descriptor: &KernelDescriptor,
    ) -> BackendResult<Kernel>;

    /// Get the memory manager for a device, boxed for object safety
    fn memory_manager(
        &self,
        device: &Device,
    ) -> BackendResult<Box<dyn MemoryManager + Send + Sync>>;

    /// Get the profiler for performance analysis, boxed for object safety
    fn profiler(&self) -> BackendResult<Box<dyn Profiler + Send + Sync>>;

    /// Create a buffer with automatic cleanup (replaces generic scoped resource for object safety)
    ///
    /// NOTE(review): the signature is identical to `create_buffer`; the
    /// "scoped" cleanup semantics must come from the implementation — confirm
    /// what distinguishes the two in concrete backends.
    fn create_scoped_buffer(
        &self,
        device: &Device,
        descriptor: &BufferDescriptor,
    ) -> BackendResult<Buffer>;
}
97
/// Advanced resource management operations with generic support (separate trait for type safety).
///
/// Kept separate from [`BackendResourceManager`] because the generic method
/// below makes this trait non-object-safe (it cannot be used as `dyn`).
pub trait BackendAdvancedResourceManager: Send + Sync {
    /// Create a resource via `factory`, registering `cleanup` to run when the
    /// returned [`ManagedResource`] is dropped (unless the resource is taken
    /// out first).
    fn create_resource_with_cleanup<T, F>(
        &self,
        device: &Device,
        factory: F,
        cleanup: impl FnOnce(&T) + Send + 'static,
    ) -> BackendResult<ManagedResource<T>>
    where
        T: Send + Sync + 'static,
        F: FnOnce(&Device) -> BackendResult<T>;
}
111
/// Execution operations: async data movement and kernel dispatch.
#[async_trait::async_trait]
pub trait BackendExecutor: Send + Sync {
    /// Synchronize operations on a device (wait for all pending work to complete)
    async fn synchronize(&self, device: &Device) -> BackendResult<()>;

    /// Copy `size` units from `src` (at `src_offset`) into `dst` (at `dst_offset`).
    ///
    /// Offsets/size are presumably in bytes, matching the `&[u8]` host-copy
    /// APIs below — confirm with implementations.
    async fn copy_buffer(
        &self,
        src: &Buffer,
        dst: &Buffer,
        src_offset: usize,
        dst_offset: usize,
        size: usize,
    ) -> BackendResult<()>;

    /// Copy host bytes `src` into device buffer `dst` starting at `dst_offset`
    async fn copy_to_device(
        &self,
        src: &[u8],
        dst: &Buffer,
        dst_offset: usize,
    ) -> BackendResult<()>;

    /// Copy from device buffer `src` (at `src_offset`) into host slice `dst`.
    ///
    /// There is no explicit size parameter; presumably `dst.len()` bytes are
    /// read — confirm with implementations.
    async fn copy_from_device(
        &self,
        src: &Buffer,
        dst: &mut [u8],
        src_offset: usize,
    ) -> BackendResult<()>;

    /// Execute a kernel with the given parameters.
    ///
    /// `buffers` are bound in order, `uniform_data` is passed as raw bytes,
    /// and `workgroup_size`/`workgroup_count` follow the usual GPU dispatch
    /// convention (threads per group / number of groups) — confirm per backend.
    async fn execute_kernel(
        &self,
        kernel: &Kernel,
        buffers: &[&Buffer],
        uniform_data: &[u8],
        workgroup_size: (u32, u32, u32),
        workgroup_count: (u32, u32, u32),
    ) -> BackendResult<()>;
}
154
/// Specialized operations support.
///
/// Each getter returns a freshly boxed operations object; use
/// [`BackendOperations::operations_bundle`] to fetch them all at once.
pub trait BackendOperations: Send + Sync {
    /// Get FFT operations for this backend
    fn fft_ops(&self) -> Box<dyn crate::fft::FftOps>;

    /// Get convolution operations for this backend
    fn convolution_ops(&self) -> Box<dyn crate::convolution::ConvolutionOps>;

    /// Get RNN operations for this backend
    fn rnn_ops(&self) -> Box<dyn crate::rnn::RnnOps>;

    /// Get sparse operations for this backend (currently `f32` only)
    fn sparse_ops(&self) -> Box<dyn crate::sparse_ops::SparseOps<f32>>;

    /// Get quantization operations for this backend
    fn quantization_ops(&self) -> Box<dyn crate::quantization::QuantizationOps>;

    /// Get operations bundle for efficient access to multiple operation types
    fn operations_bundle(&self) -> OperationsBundle;
}
175
/// The main backend trait that combines all backend functionality.
///
/// The `as_*` methods provide explicit upcasts from `dyn Backend` to the
/// component traits (useful where `dyn`-to-`dyn` supertrait upcasting is not
/// available to callers).
pub trait Backend:
    BackendCore
    + BackendLifecycle
    + BackendDeviceManager
    + BackendResourceManager
    + BackendExecutor
    + BackendOperations
    + BackendOps
{
    /// Get a reference to this backend as BackendCore
    fn as_core(&self) -> &dyn BackendCore;

    /// Get a mutable reference to this backend as BackendLifecycle
    fn as_lifecycle(&mut self) -> &mut dyn BackendLifecycle;

    /// Get a reference to this backend as BackendDeviceManager
    fn as_device_manager(&self) -> &dyn BackendDeviceManager;

    /// Get a reference to this backend as BackendResourceManager
    fn as_resource_manager(&self) -> &dyn BackendResourceManager;

    /// Get a reference to this backend as BackendExecutor
    fn as_executor(&self) -> &dyn BackendExecutor;

    /// Get a reference to this backend as BackendOperations
    fn as_operations(&self) -> &dyn BackendOperations;
}
204
/// RAII wrapper for backend resources with automatic cleanup and better type safety.
///
/// Holds an optional resource and an optional cleanup closure. On drop, the
/// closure (if any) receives the resource *by value* — unless ownership was
/// moved out earlier via [`ScopedResource::take`], which disarms cleanup.
pub struct ScopedResource<'a, T> {
    resource: Option<T>,
    cleanup: Option<Box<dyn FnOnce(T) + Send + 'a>>,
}

impl<'a, T> ScopedResource<'a, T> {
    /// Wrap a resource with no cleanup action.
    pub fn new(resource: T) -> Self {
        Self {
            resource: Some(resource),
            cleanup: None,
        }
    }

    /// Wrap a resource and register a cleanup closure to run on drop.
    pub fn new_with_cleanup<F>(resource: T, cleanup: F) -> Self
    where
        F: FnOnce(T) + Send + 'a,
    {
        let hook: Box<dyn FnOnce(T) + Send + 'a> = Box::new(cleanup);
        Self {
            resource: Some(resource),
            cleanup: Some(hook),
        }
    }

    /// Borrow the resource, or `None` if it was taken out.
    pub fn get(&self) -> Option<&T> {
        match self.resource {
            Some(ref r) => Some(r),
            None => None,
        }
    }

    /// Mutably borrow the resource, or `None` if it was taken out.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        match self.resource {
            Some(ref mut r) => Some(r),
            None => None,
        }
    }

    /// Move the resource out, disarming the cleanup closure.
    pub fn take(mut self) -> Option<T> {
        // Dropping the closure without calling it means no cleanup runs.
        self.cleanup = None;
        self.resource.take()
    }

    /// Apply `f` to a borrow of the resource and return its result, or
    /// `None` if the resource was taken out.
    pub fn with_resource<F, R>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&T) -> R,
    {
        Some(f(self.resource.as_ref()?))
    }

    /// Whether the resource is still present.
    pub fn is_available(&self) -> bool {
        self.get().is_some()
    }
}

impl<'a, T> Drop for ScopedResource<'a, T> {
    fn drop(&mut self) {
        // Run the cleanup hook at most once, and only if the resource was
        // never taken out of the wrapper.
        if let Some(hook) = self.cleanup.take() {
            if let Some(resource) = self.resource.take() {
                hook(resource);
            }
        }
    }
}
268
/// Managed resource with automatic cleanup (no lifetime parameter for better usability).
///
/// Unlike [`ScopedResource`], the cleanup closure receives the resource by
/// *reference*; the resource itself is dropped normally afterwards. Moving the
/// resource out via [`ManagedResource::take`] disarms the cleanup closure.
pub struct ManagedResource<T> {
    resource: Option<T>,
    cleanup: Option<Box<dyn FnOnce(&T) + Send + 'static>>,
}

impl<T> ManagedResource<T> {
    /// Wrap a resource with no cleanup action.
    pub fn new(resource: T) -> Self {
        Self {
            resource: Some(resource),
            cleanup: None,
        }
    }

    /// Wrap a resource and register a cleanup closure to run on drop.
    pub fn new_with_cleanup<F>(resource: T, cleanup: F) -> Self
    where
        F: FnOnce(&T) + Send + 'static,
    {
        let hook: Box<dyn FnOnce(&T) + Send + 'static> = Box::new(cleanup);
        Self {
            resource: Some(resource),
            cleanup: Some(hook),
        }
    }

    /// Borrow the resource, or `None` if it was taken out.
    pub fn get(&self) -> Option<&T> {
        match self.resource {
            Some(ref r) => Some(r),
            None => None,
        }
    }

    /// Mutably borrow the resource, or `None` if it was taken out.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        match self.resource {
            Some(ref mut r) => Some(r),
            None => None,
        }
    }

    /// Move the resource out, disarming the cleanup closure.
    pub fn take(mut self) -> Option<T> {
        // Dropping the closure without calling it means no cleanup runs.
        self.cleanup = None;
        self.resource.take()
    }

    /// Apply `f` to a borrow of the resource and return its result, or
    /// `None` if the resource was taken out.
    pub fn with_resource<F, R>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&T) -> R,
    {
        self.get().map(f)
    }

    /// Whether the resource is still present.
    pub fn is_available(&self) -> bool {
        self.get().is_some()
    }
}

impl<T> Drop for ManagedResource<T> {
    fn drop(&mut self) {
        // Run the cleanup hook at most once, by reference; the resource itself
        // is dropped normally after this body finishes.
        if let Some(hook) = self.cleanup.take() {
            if let Some(ref resource) = self.resource {
                hook(resource);
            }
        }
    }
}

// SAFETY: the only field that blocks the auto impls is the boxed cleanup
// closure. It is `Send` by construction, and it is only ever invoked through
// `&mut self` (in `Drop`), so sharing `&ManagedResource<T>` across threads
// never touches it; thread-safety therefore reduces to that of `T`.
unsafe impl<T: Send> Send for ManagedResource<T> {}
unsafe impl<T: Sync> Sync for ManagedResource<T> {}
336
/// Bundle of operations for efficient access.
///
/// Groups all specialized operation objects so callers can fetch them with a
/// single [`BackendOperations::operations_bundle`] call instead of five
/// separate getters. Fields are public by design.
pub struct OperationsBundle {
    // FFT operations
    pub fft: Box<dyn crate::fft::FftOps>,
    // Convolution operations
    pub convolution: Box<dyn crate::convolution::ConvolutionOps>,
    // RNN operations
    pub rnn: Box<dyn crate::rnn::RnnOps>,
    // Sparse operations (currently f32 only)
    pub sparse: Box<dyn crate::sparse_ops::SparseOps<f32>>,
    // Quantization operations
    pub quantization: Box<dyn crate::quantization::QuantizationOps>,
}

impl OperationsBundle {
    /// Create a new operations bundle from the individual operation objects
    pub fn new(
        fft: Box<dyn crate::fft::FftOps>,
        convolution: Box<dyn crate::convolution::ConvolutionOps>,
        rnn: Box<dyn crate::rnn::RnnOps>,
        sparse: Box<dyn crate::sparse_ops::SparseOps<f32>>,
        quantization: Box<dyn crate::quantization::QuantizationOps>,
    ) -> Self {
        Self {
            fft,
            convolution,
            rnn,
            sparse,
            quantization,
        }
    }
}
364
/// Backend capabilities description with enhanced extensibility.
///
/// Conservative defaults are provided by the `Default` impl further down in
/// this file; backends should override the fields they can report accurately.
#[derive(Debug, Clone)]
pub struct BackendCapabilities {
    /// Maximum buffer size in bytes
    pub max_buffer_size: usize,

    /// Maximum number of compute units
    pub max_compute_units: usize,

    /// Maximum workgroup size per dimension (x, y, z)
    pub max_workgroup_size: (u32, u32, u32),

    /// Supported data types
    pub supported_dtypes: Vec<DType>,

    /// Whether the backend supports async operations
    pub supports_async: bool,

    /// Whether the backend supports unified (host/device shared) memory
    pub supports_unified_memory: bool,

    /// Whether the backend supports sub-buffers
    pub supports_sub_buffers: bool,

    /// Whether the backend supports kernel compilation caching
    pub supports_kernel_caching: bool,

    /// Memory bandwidth in GB/s
    pub memory_bandwidth_gbps: f32,

    /// Compute throughput in GFLOPS
    pub compute_throughput_gflops: f32,

    /// Extended capabilities for better extensibility
    pub extended_capabilities: ExtendedCapabilities,
}
401
/// Extended capabilities for future extensibility.
///
/// Carries structured hardware/execution details plus a free-form
/// `custom_capabilities` map for backend-specific flags.
#[derive(Debug, Clone)]
pub struct ExtendedCapabilities {
    /// Maximum supported tensor rank (None means no limit)
    pub max_tensor_dims: Option<usize>,

    /// Supported precision modes
    pub precision_modes: Vec<PrecisionMode>,

    /// Hardware-specific features
    pub hardware_features: Vec<HardwareFeature>,

    /// Memory hierarchy information
    pub memory_hierarchy: MemoryHierarchy,

    /// Execution model capabilities
    pub execution_model: ExecutionModel,

    /// Custom capabilities for backend-specific features
    pub custom_capabilities: std::collections::HashMap<String, CapabilityValue>,
}
423
/// Precision modes supported by the backend.
///
/// Derives `Copy`, `Eq`, and `Hash` (all variants are plain data) so modes
/// can be passed by value and stored in sets/maps, consistent with the other
/// capability enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrecisionMode {
    /// 16-bit floating point
    F16,
    /// 32-bit floating point
    F32,
    /// 64-bit floating point
    F64,
    /// Mixed precision (automatic)
    Mixed,
    /// Custom precision with the given number of bits
    Custom(u8),
}
438
/// Hardware-specific features.
///
/// Derives `Eq` and `Hash` so features can be collected into sets for fast
/// membership checks (`Copy` is not possible because of `Custom(String)`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum HardwareFeature {
    /// Tensor cores (like CUDA Tensor Cores)
    TensorCores,
    /// Vector processing units
    VectorUnits,
    /// Shared memory
    SharedMemory,
    /// Constant memory
    ConstantMemory,
    /// Atomic operations
    AtomicOperations,
    /// Cooperative groups
    CooperativeGroups,
    /// Dynamic parallelism
    DynamicParallelism,
    /// Custom, backend-defined feature identified by name
    Custom(String),
}
459
/// Memory hierarchy information.
///
/// All fields are optional: `None` means "unknown/not reported" for that
/// level. Derives `PartialEq` for consistency with the other capability types
/// in this module (no `Eq` because of the `f32` field).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct MemoryHierarchy {
    /// L1 cache size in bytes
    pub l1_cache_size: Option<usize>,
    /// L2 cache size in bytes
    pub l2_cache_size: Option<usize>,
    /// L3 cache size in bytes
    pub l3_cache_size: Option<usize>,
    /// Shared memory size in bytes
    pub shared_memory_size: Option<usize>,
    /// Memory access latency in cycles
    pub memory_latency_cycles: Option<u32>,
    /// Memory bandwidth per core in GB/s
    pub memory_bandwidth_per_core: Option<f32>,
}
476
/// Execution model capabilities.
///
/// Derives `PartialEq`/`Eq` for consistency with the other capability types
/// in this module. Defaults come from the manual `Default` impl further down
/// (a minimal single-stream, in-order model).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionModel {
    /// Whether the backend supports SIMD operations
    pub supports_simd: bool,
    /// Whether the backend supports SIMT operations
    pub supports_simt: bool,
    /// Whether the backend supports task parallelism
    pub supports_task_parallelism: bool,
    /// Whether the backend supports data parallelism
    pub supports_data_parallelism: bool,
    /// Maximum concurrent streams/queues (None means unknown)
    pub max_concurrent_streams: Option<u32>,
    /// Whether the backend supports out-of-order execution
    pub supports_out_of_order: bool,
}
493
/// Capability value for custom capabilities.
///
/// A small dynamically-typed value used in capability maps and as the
/// argument/return type of `BackendExtension::handle_operation`. Derives
/// `PartialEq` so callers can compare reported values (no `Eq` because of
/// the `f64` variant).
#[derive(Debug, Clone, PartialEq)]
pub enum CapabilityValue {
    /// Boolean flag
    Bool(bool),
    /// Signed integer value
    Int(i64),
    /// Floating-point value
    Float(f64),
    /// Text value
    String(String),
    /// Ordered list of nested values
    List(Vec<CapabilityValue>),
}
503
impl Default for ExtendedCapabilities {
    /// Conservative defaults: up to 8 tensor dimensions, `f32` precision only,
    /// no special hardware features, and empty hierarchy/custom-capability data.
    fn default() -> Self {
        Self {
            max_tensor_dims: Some(8),
            precision_modes: vec![PrecisionMode::F32],
            hardware_features: vec![],
            memory_hierarchy: MemoryHierarchy::default(),
            execution_model: ExecutionModel::default(),
            custom_capabilities: std::collections::HashMap::new(),
        }
    }
}
516
impl Default for ExecutionModel {
    /// Defaults model a minimal executor: task/data parallelism only, a single
    /// in-order stream, and no SIMD/SIMT support.
    fn default() -> Self {
        Self {
            supports_simd: false,
            supports_simt: false,
            supports_task_parallelism: true,
            supports_data_parallelism: true,
            max_concurrent_streams: Some(1),
            supports_out_of_order: false,
        }
    }
}
529
impl Default for BackendCapabilities {
    /// Conservative baseline capabilities: 1 GB buffers, a single compute
    /// unit, common numeric dtypes, and all optional features disabled.
    /// Backends should override with real values where known.
    fn default() -> Self {
        Self {
            max_buffer_size: 1024 * 1024 * 1024, // 1GB
            max_compute_units: 1,
            max_workgroup_size: (256, 1, 1),
            supported_dtypes: vec![DType::F32, DType::F64, DType::I32, DType::I64],
            supports_async: false,
            supports_unified_memory: false,
            supports_sub_buffers: false,
            supports_kernel_caching: false,
            memory_bandwidth_gbps: 10.0,
            compute_throughput_gflops: 100.0,
            extended_capabilities: ExtendedCapabilities::default(),
        }
    }
}
547
/// Performance optimization hints.
///
/// Advisory values reported by a backend (via `BackendCore::performance_hints`)
/// that callers may use to tune dispatch and memory layout; defaults are
/// provided by the `Default` impl below.
#[derive(Debug, Clone)]
pub struct PerformanceHints {
    /// Preferred workgroup size for compute kernels (x, y, z)
    pub preferred_workgroup_size: (u32, u32, u32),

    /// Optimal memory alignment in bytes
    pub memory_alignment: usize,

    /// Whether to prefer vectorized operations
    pub prefer_vectorized: bool,

    /// Whether to use asynchronous operations when possible
    pub prefer_async: bool,

    /// Optimal batch size for operations
    pub optimal_batch_size: usize,

    /// Whether to cache compiled kernels
    pub cache_kernels: bool,
}
569
impl Default for PerformanceHints {
    /// Generic defaults: 64-wide workgroups, 16-byte alignment, vectorization
    /// and kernel caching on, async off, batch size 32.
    fn default() -> Self {
        Self {
            preferred_workgroup_size: (64, 1, 1),
            memory_alignment: 16,
            prefer_vectorized: true,
            prefer_async: false,
            optimal_batch_size: 32,
            cache_kernels: true,
        }
    }
}
582
/// Backend type enumeration.
///
/// Identifies a backend family; `Auto` requests automatic selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum BackendType {
    /// Automatic backend selection
    Auto,
    /// CPU backend
    Cpu,
    /// CUDA GPU backend
    Cuda,
    /// Metal GPU backend
    Metal,
    /// ROCm/HIP GPU backend
    Rocm,
    /// WebGPU backend
    WebGpu,
}

impl std::fmt::Display for BackendType {
    /// Formats the backend type using its conventional product capitalization
    /// (e.g. "CUDA", "ROCm", "WebGPU").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Auto => "Auto",
            Self::Cpu => "CPU",
            Self::Cuda => "CUDA",
            Self::Metal => "Metal",
            Self::Rocm => "ROCm",
            Self::WebGpu => "WebGPU",
        };
        f.write_str(label)
    }
}
613
/// Backend operations trait - groups of related operations.
///
/// Capability introspection over the specialized operation families exposed
/// via [`BackendOperations`].
pub trait BackendOps: Send + Sync {
    /// Get the backend type
    fn backend_type(&self) -> BackendType;

    /// Get the names of operations available on this backend
    fn available_ops(&self) -> Vec<&str>;

    /// Check if the operation with the given name is supported
    fn supports_op(&self, op_name: &str) -> bool;

    /// Check if FFT operations are supported
    fn supports_fft(&self) -> bool;

    /// Check if convolution operations are supported
    fn supports_convolution(&self) -> bool;

    /// Check if RNN operations are supported
    fn supports_rnn(&self) -> bool;

    /// Check if sparse operations are supported
    fn supports_sparse(&self) -> bool;

    /// Check if quantization operations are supported
    fn supports_quantization(&self) -> bool;

    /// Get operation-specific capabilities, or `None` if the operation is
    /// unknown to this backend
    fn operation_capabilities(
        &self,
        op_name: &str,
    ) -> Option<std::collections::HashMap<String, CapabilityValue>>;
}
646
/// Backend extensibility trait for adding custom functionality.
///
/// Extensions are registered with a [`BackendExtensionRegistry`], keyed by
/// `extension_name()`, and follow an initialize/shutdown lifecycle driven by
/// the registry.
pub trait BackendExtension: Send + Sync {
    /// Get the extension name (used as its unique registry key)
    fn extension_name(&self) -> &str;

    /// Get the extension version
    fn extension_version(&self) -> &str;

    /// Check if this extension is compatible with the given backend
    fn is_compatible_with(&self, backend: &dyn BackendCore) -> bool;

    /// Initialize the extension against the given backend
    fn initialize(&mut self, backend: &dyn Backend) -> BackendResult<()>;

    /// Shutdown the extension, releasing any resources it holds
    fn shutdown(&mut self) -> BackendResult<()>;

    /// Get extension-specific capabilities
    fn capabilities(&self) -> std::collections::HashMap<String, CapabilityValue>;

    /// Handle a custom operation by name with dynamically-typed arguments
    fn handle_operation(
        &self,
        op_name: &str,
        args: &[CapabilityValue],
    ) -> BackendResult<CapabilityValue>;
}
674
675/// Backend registry for managing extensions with enhanced lifetime safety
676pub struct BackendExtensionRegistry {
677    extensions: std::collections::HashMap<String, Box<dyn BackendExtension>>,
678    /// Track initialization state for proper cleanup
679    initialized_extensions: std::collections::HashSet<String>,
680}
681
682impl BackendExtensionRegistry {
683    /// Create a new extension registry
684    pub fn new() -> Self {
685        Self {
686            extensions: std::collections::HashMap::new(),
687            initialized_extensions: std::collections::HashSet::new(),
688        }
689    }
690
691    /// Register a new extension with ownership transfer
692    pub fn register_extension(
693        &mut self,
694        extension: Box<dyn BackendExtension>,
695    ) -> BackendResult<()> {
696        let name = extension.extension_name().to_string();
697        if self.extensions.contains_key(&name) {
698            return Err(TorshError::BackendError(format!(
699                "Extension '{}' is already registered",
700                name
701            )));
702        }
703        self.extensions.insert(name, extension);
704        Ok(())
705    }
706
707    /// Get an extension by name with proper lifetime bounds
708    pub fn get_extension(&self, name: &str) -> Option<&dyn BackendExtension> {
709        self.extensions.get(name).map(|e| e.as_ref())
710    }
711
712    /// Get a mutable extension by name with proper lifetime bounds
713    pub fn get_extension_mut(&mut self, name: &str) -> Option<&mut Box<dyn BackendExtension>> {
714        self.extensions.get_mut(name)
715    }
716
717    /// Get all registered extensions
718    pub fn extensions(&self) -> Vec<&str> {
719        self.extensions.keys().map(|s| s.as_str()).collect()
720    }
721
722    /// Initialize all extensions with a backend with proper error handling
723    pub fn initialize_all(&mut self, backend: &dyn Backend) -> BackendResult<Vec<String>> {
724        let mut failed_extensions = Vec::new();
725
726        for (name, extension) in self.extensions.iter_mut() {
727            if extension.is_compatible_with(backend.as_core()) {
728                match extension.initialize(backend) {
729                    Ok(()) => {
730                        self.initialized_extensions.insert(name.clone());
731                    }
732                    Err(e) => {
733                        failed_extensions.push(format!("{}: {}", name, e));
734                    }
735                }
736            }
737        }
738
739        if failed_extensions.is_empty() {
740            Ok(vec![])
741        } else {
742            Err(TorshError::BackendError(format!(
743                "Failed to initialize extensions: {}",
744                failed_extensions.join(", ")
745            )))
746        }
747    }
748
749    /// Shutdown all extensions with proper error handling
750    pub fn shutdown_all(&mut self) -> BackendResult<Vec<String>> {
751        let mut failed_extensions = Vec::new();
752
753        // Only shutdown initialized extensions
754        for (name, extension) in self.extensions.iter_mut() {
755            if self.initialized_extensions.contains(name) {
756                if let Err(e) = extension.shutdown() {
757                    failed_extensions.push(format!("{}: {}", name, e));
758                } else {
759                    self.initialized_extensions.remove(name);
760                }
761            }
762        }
763
764        if failed_extensions.is_empty() {
765            Ok(vec![])
766        } else {
767            Err(TorshError::BackendError(format!(
768                "Failed to shutdown extensions: {}",
769                failed_extensions.join(", ")
770            )))
771        }
772    }
773
774    /// Remove an extension by name with proper cleanup
775    pub fn remove_extension(&mut self, name: &str) -> Option<Box<dyn BackendExtension>> {
776        // Ensure extension is shutdown before removal
777        if let Some(extension) = self.extensions.get_mut(name) {
778            if self.initialized_extensions.contains(name) {
779                let _ = extension.shutdown(); // Ignore errors during removal
780                self.initialized_extensions.remove(name);
781            }
782        }
783        self.extensions.remove(name)
784    }
785
786    /// Check if an extension is registered
787    pub fn has_extension(&self, name: &str) -> bool {
788        self.extensions.contains_key(name)
789    }
790
791    /// Get the number of registered extensions
792    pub fn len(&self) -> usize {
793        self.extensions.len()
794    }
795
796    /// Check if the registry is empty
797    pub fn is_empty(&self) -> bool {
798        self.extensions.is_empty()
799    }
800}
801
802impl Default for BackendExtensionRegistry {
803    fn default() -> Self {
804        Self::new()
805    }
806}
807
/// Trait for backend factories.
///
/// A factory describes and constructs backends of one device type; the
/// `priority` value lets a selector rank competing factories.
pub trait BackendFactory: Send + Sync {
    /// Create a new backend instance
    fn create(&self) -> BackendResult<Box<dyn Backend>>;

    /// Get the device type this factory creates backends for
    fn device_type(&self) -> DeviceType;

    /// Check if this backend type is available on the current system
    fn is_available(&self) -> bool;

    /// Get the selection priority of this backend (higher is better)
    fn priority(&self) -> u32;

    /// Get the capabilities of backends created by this factory
    fn capabilities(&self) -> BackendCapabilities;
}
825
826/// Device enumeration and selection utilities
827pub struct DeviceEnumerator;
828
829impl DeviceEnumerator {
830    /// Enumerate all available devices across all backends
831    pub fn enumerate_all_devices() -> BackendResult<Vec<(DeviceType, Vec<Device>)>> {
832        let mut all_devices = Vec::new();
833
834        // CPU devices are always available
835        #[cfg(feature = "cpu")]
836        {
837            if let Ok(cpu_backend) = crate::cpu::CpuBackend::new() {
838                if let Ok(devices) = cpu_backend.devices() {
839                    all_devices.push((DeviceType::Cpu, devices));
840                }
841            }
842        }
843
844        // CUDA devices
845        #[cfg(feature = "cuda")]
846        {
847            // Since this is a fallback CUDA implementation, no devices are available
848            // This avoids type mismatch issues with the fallback implementation
849        }
850
851        // Metal devices
852        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
853        {
854            if let Ok(metal_backend) = crate::metal::MetalBackend::new() {
855                if let Ok(devices) = metal_backend.devices() {
856                    all_devices.push((DeviceType::Metal(0), devices));
857                }
858            }
859        }
860
861        // WebGPU devices
862        #[cfg(feature = "webgpu")]
863        {
864            let webgpu_backend = crate::webgpu::WebGpuBackend::with_default_config();
865            if let Ok(devices) = webgpu_backend.devices() {
866                all_devices.push((DeviceType::Wgpu(0), devices));
867            }
868        }
869
870        Ok(all_devices)
871    }
872
873    /// Find the best available device based on performance characteristics
874    pub fn find_best_device() -> BackendResult<(DeviceType, Device)> {
875        let all_devices = Self::enumerate_all_devices()?;
876
877        if all_devices.is_empty() {
878            return Err(TorshError::BackendError("No devices available".to_string()));
879        }
880
881        // Prioritize backends: CUDA > Metal > ROCm > WebGPU > CPU
882        let backend_priorities = [
883            DeviceType::Cuda(0),
884            DeviceType::Metal(0),
885            DeviceType::Wgpu(0),
886            DeviceType::Cpu,
887        ];
888
889        for preferred_type in &backend_priorities {
890            for (device_type, devices) in &all_devices {
891                if Self::device_types_match(device_type, preferred_type) && !devices.is_empty() {
892                    // Find the device with highest compute performance
893                    let best_device = devices
894                        .iter()
895                        .max_by(|a, b| {
896                            a.info()
897                                .peak_gflops
898                                .partial_cmp(&b.info().peak_gflops)
899                                .unwrap_or(std::cmp::Ordering::Equal)
900                        })
901                        .cloned()
902                        .expect("devices should not be empty after is_empty check");
903
904                    return Ok((*device_type, best_device));
905                }
906            }
907        }
908
909        // Fallback to first available device
910        let (device_type, devices) = &all_devices[0];
911        if !devices.is_empty() {
912            Ok((*device_type, devices[0].clone()))
913        } else {
914            Err(TorshError::BackendError(
915                "No usable devices found".to_string(),
916            ))
917        }
918    }
919
920    /// Helper to match device types ignoring IDs
921    fn device_types_match(a: &DeviceType, b: &DeviceType) -> bool {
922        matches!(
923            (a, b),
924            (DeviceType::Cpu, DeviceType::Cpu)
925                | (DeviceType::Cuda(_), DeviceType::Cuda(_))
926                | (DeviceType::Metal(_), DeviceType::Metal(_))
927                | (DeviceType::Wgpu(_), DeviceType::Wgpu(_))
928        )
929    }
930
    /// Get devices by type
    ///
    /// Enumerates all devices of the requested backend kind. Each match arm
    /// is compile-time gated on the corresponding cargo feature, so a device
    /// type whose feature was not compiled in falls through to the error arm.
    ///
    /// # Errors
    /// Returns `TorshError::BackendError` when the requested backend is not
    /// available in this build, or propagates any error raised while
    /// constructing the backend.
    pub fn get_devices_by_type(device_type: DeviceType) -> BackendResult<Vec<Device>> {
        match device_type {
            #[cfg(feature = "cpu")]
            DeviceType::Cpu => {
                let cpu_backend = crate::cpu::CpuBackend::new()?;
                cpu_backend.devices()
            }
            #[cfg(feature = "cuda")]
            DeviceType::Cuda(_device_id) => {
                // Since this is a fallback CUDA implementation, return empty vector
                // NOTE(review): `is_device_type_available` gates real CUDA support
                // on `cfg(cuda_available)` and constructs a CudaBackend, while this
                // arm always reports zero devices — confirm whether CUDA device
                // enumeration should be wired up here as well.
                Ok(vec![])
            }
            // Metal enumeration is only meaningful on Apple Silicon macOS.
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            DeviceType::Metal(_) => {
                let metal_backend = crate::metal::MetalBackend::new()?;
                metal_backend.devices()
            }
            #[cfg(feature = "webgpu")]
            DeviceType::Wgpu(_) => {
                let webgpu_backend = crate::webgpu::WebGpuBackend::with_default_config();
                webgpu_backend.devices()
            }
            // Reachability depends on the enabled feature set; with all
            // features on, this arm is unreachable — hence the allow.
            #[allow(unreachable_patterns)]
            _ => Err(TorshError::BackendError(format!(
                "Backend type {device_type:?} not available"
            ))),
        }
    }
960
    /// Check if a specific device type is available
    ///
    /// Cheap capability probe: returns `true` only when the matching backend
    /// feature was compiled in AND (where applicable) the backend can actually
    /// be constructed on this machine. Never returns an error — construction
    /// failures are mapped to `false`.
    pub fn is_device_type_available(device_type: DeviceType) -> bool {
        match device_type {
            // CPU backend is assumed constructible whenever the feature is on.
            #[cfg(feature = "cpu")]
            DeviceType::Cpu => true,
            // `cuda_available` is a custom cfg flag (presumably set by the
            // build script when a CUDA toolchain is detected — TODO confirm).
            #[cfg(cuda_available)]
            DeviceType::Cuda(device_id) => {
                crate::cuda::CudaBackend::new(crate::cuda::CudaBackendConfig {
                    device_id: device_id as usize,
                    ..Default::default()
                })
                .is_ok()
            }
            #[cfg(all(feature = "cuda", not(cuda_available)))]
            DeviceType::Cuda(_) => false, // CUDA feature enabled but not available on this platform
            // Metal availability is probed by attempting backend construction.
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            DeviceType::Metal(_) => crate::metal::MetalBackend::new().is_ok(),
            #[cfg(feature = "webgpu")]
            DeviceType::Wgpu(_) => true, // WebGPU backend with default config is always available
            // Any device type without a compiled-in backend is unavailable.
            #[allow(unreachable_patterns)]
            _ => false,
        }
    }
984}
985
/// Backend plugin system for dynamic backend loading
///
/// An implementor describes a loadable backend: its identity and version,
/// whether it can run on the current system, and a factory that produces
/// [`Backend`] instances. Plugins are registered with and consumed through
/// `BackendRegistry`.
pub trait BackendPlugin: Send + Sync + std::fmt::Debug {
    /// Get the name of this plugin (also used as its registry key)
    fn name(&self) -> &str;

    /// Get the version of this plugin
    fn version(&self) -> &str;

    /// Create a backend instance from this plugin
    fn create_backend(&self) -> BackendResult<Box<dyn Backend>>;

    /// Check if this plugin is compatible with the current system
    /// (incompatible plugins are rejected at registration time)
    fn is_compatible(&self) -> bool;

    /// Get the device types this plugin supports
    fn supported_device_types(&self) -> Vec<DeviceType>;

    /// Get plugin metadata
    fn metadata(&self) -> PluginMetadata;
}
1006
/// Plugin metadata information
///
/// Descriptive, human-readable data about a backend plugin, returned by
/// [`BackendPlugin::metadata`].
#[derive(Debug, Clone)]
pub struct PluginMetadata {
    /// Plugin name (presumably matches `BackendPlugin::name` — TODO confirm)
    pub name: String,
    /// Plugin version string
    pub version: String,
    /// Human-readable description of the plugin
    pub description: String,
    /// Plugin author
    pub author: String,
    /// License identifier or text
    pub license: String,
    /// Architectures the plugin supports (free-form strings)
    pub supported_architectures: Vec<String>,
    /// Features that must be enabled for the plugin to work
    pub required_features: Vec<String>,
    /// Features the plugin can take advantage of but does not require
    pub optional_features: Vec<String>,
}
1019
/// Backend resource monitoring and management trait for better RAII patterns
///
/// Gives backends a uniform surface for reporting usage, enforcing limits,
/// and reclaiming unused resources. Monitoring can be toggled at runtime.
pub trait BackendResourceMonitor: Send + Sync {
    /// Get the current resource usage snapshot
    fn resource_usage(&self) -> ResourceUsage;

    /// Set resource limits (may fail if the backend rejects the limits)
    fn set_resource_limits(&mut self, limits: ResourceLimits) -> BackendResult<()>;

    /// Get the currently configured resource limits
    fn resource_limits(&self) -> ResourceLimits;

    /// Cleanup unused resources (e.g. free cached allocations)
    fn cleanup_resources(&mut self) -> BackendResult<()>;

    /// Get accumulated resource statistics
    fn resource_statistics(&self) -> ResourceStatistics;

    /// Enable resource monitoring
    fn enable_monitoring(&mut self) -> BackendResult<()>;

    /// Disable resource monitoring
    fn disable_monitoring(&mut self) -> BackendResult<()>;

    /// Check if monitoring is enabled
    fn is_monitoring_enabled(&self) -> bool;
}
1046
/// Current resource usage information
///
/// Point-in-time snapshot returned by [`BackendResourceMonitor::resource_usage`].
#[derive(Debug, Clone)]
pub struct ResourceUsage {
    /// Memory currently in use (presumably bytes — TODO confirm with backends)
    pub memory_used: usize,
    /// Number of buffers currently allocated
    pub buffers_allocated: usize,
    /// Number of compiled kernels held in cache
    pub kernels_cached: usize,
    /// Number of active execution streams
    pub active_streams: usize,
    /// CPU utilization, 0.0–100.0 (percentage)
    pub cpu_usage_percent: f32,
    /// GPU utilization, 0.0–100.0 (percentage)
    pub gpu_usage_percent: f32,
}
1057
/// Resource limits configuration
///
/// Limits applied via [`BackendResourceMonitor::set_resource_limits`].
/// `None` presumably means "no limit" — confirm with the enforcing backend.
#[derive(Debug, Clone)]
pub struct ResourceLimits {
    /// Maximum memory the backend may allocate
    pub max_memory: Option<usize>,
    /// Maximum number of simultaneously allocated buffers
    pub max_buffers: Option<usize>,
    /// Maximum number of cached kernels
    pub max_kernels: Option<usize>,
    /// Maximum number of execution streams
    pub max_streams: Option<usize>,
    /// Fraction of memory use at which pressure handling kicks in
    /// (NOTE(review): range/semantics not enforced here — verify in backends)
    pub memory_pressure_threshold: f32,
}
1067
/// Resource statistics over time
///
/// Cumulative counters returned by [`BackendResourceMonitor::resource_statistics`].
#[derive(Debug, Clone)]
pub struct ResourceStatistics {
    /// High-water mark of memory usage
    pub peak_memory_usage: usize,
    /// Total number of allocations performed
    pub total_allocations: u64,
    /// Total number of deallocations performed
    pub total_deallocations: u64,
    /// Mean buffer size across allocations
    pub average_buffer_size: f32,
    /// Kernel/buffer cache hit rate (NOTE(review): unclear whether 0–1 or
    /// 0–100 — confirm against the producing backend)
    pub cache_hit_rate: f32,
    /// Number of allocation attempts that failed
    pub allocation_failure_count: u32,
}
1078
/// Backend registry for managing multiple backends and plugins
pub struct BackendRegistry {
    /// Registered plugins, keyed by `BackendPlugin::name()`
    backends: std::collections::HashMap<String, Box<dyn BackendPlugin>>,
    /// Name of the default backend; set to the first registered plugin
    /// unless overridden via `set_default_backend`
    default_backend: Option<String>,
}
1084
1085impl BackendRegistry {
1086    /// Create a new backend registry
1087    pub fn new() -> Self {
1088        Self {
1089            backends: std::collections::HashMap::new(),
1090            default_backend: None,
1091        }
1092    }
1093
1094    /// Register a new backend plugin
1095    pub fn register_plugin(&mut self, plugin: Box<dyn BackendPlugin>) -> BackendResult<()> {
1096        let name = plugin.name().to_string();
1097
1098        // Check if plugin is compatible
1099        if !plugin.is_compatible() {
1100            return Err(TorshError::BackendError(format!(
1101                "Plugin {name} is not compatible with current system"
1102            )));
1103        }
1104
1105        self.backends.insert(name.clone(), plugin);
1106
1107        // Set as default if this is the first compatible plugin
1108        if self.default_backend.is_none() {
1109            self.default_backend = Some(name);
1110        }
1111
1112        Ok(())
1113    }
1114
1115    /// Get available backend names
1116    pub fn available_backends(&self) -> Vec<String> {
1117        self.backends.keys().cloned().collect()
1118    }
1119
1120    /// Create a backend by name
1121    pub fn create_backend(&self, name: &str) -> BackendResult<Box<dyn Backend>> {
1122        if let Some(plugin) = self.backends.get(name) {
1123            plugin.create_backend()
1124        } else {
1125            Err(TorshError::BackendError(format!(
1126                "Backend {name} not found"
1127            )))
1128        }
1129    }
1130
1131    /// Create the default backend
1132    pub fn create_default_backend(&self) -> BackendResult<Box<dyn Backend>> {
1133        if let Some(default_name) = &self.default_backend {
1134            self.create_backend(default_name)
1135        } else {
1136            Err(TorshError::BackendError(
1137                "No default backend available".to_string(),
1138            ))
1139        }
1140    }
1141
1142    /// Set the default backend
1143    pub fn set_default_backend(&mut self, name: &str) -> BackendResult<()> {
1144        if self.backends.contains_key(name) {
1145            self.default_backend = Some(name.to_string());
1146            Ok(())
1147        } else {
1148            Err(TorshError::BackendError(format!(
1149                "Backend {name} not found"
1150            )))
1151        }
1152    }
1153
1154    /// Get plugin metadata
1155    pub fn get_plugin_metadata(&self, name: &str) -> Option<PluginMetadata> {
1156        self.backends.get(name).map(|plugin| plugin.metadata())
1157    }
1158}
1159
1160impl Default for BackendRegistry {
1161    fn default() -> Self {
1162        Self::new()
1163    }
1164}
1165
/// Backend configuration trait for customizing backend behavior
///
/// Configurations round-trip through string-keyed property maps
/// (`as_properties` / `from_properties`) and support validation and merging.
pub trait BackendConfig: Send + Sync + Clone {
    /// Get the backend type this configuration is for
    fn backend_type(&self) -> BackendType;

    /// Validate the configuration
    fn validate(&self) -> BackendResult<()>;

    /// Get configuration as key-value pairs
    fn as_properties(&self) -> std::collections::HashMap<String, CapabilityValue>;

    /// Set configuration from key-value pairs
    /// (inverse of `as_properties`; fails on invalid/missing properties)
    fn from_properties(
        properties: &std::collections::HashMap<String, CapabilityValue>,
    ) -> BackendResult<Self>
    where
        Self: Sized;

    /// Merge with another configuration
    /// (NOTE(review): precedence between `self` and `other` on conflicting
    /// keys is implementation-defined — document per implementor)
    fn merge(&mut self, other: &Self) -> BackendResult<()>;

    /// Get default configuration
    fn default_config() -> Self
    where
        Self: Sized;
}
1192
/// Backend builder trait for creating configured backends
///
/// Standard builder shape: start from defaults, optionally replace the
/// configuration, then `build` a boxed [`Backend`].
pub trait BackendBuilder<T: BackendConfig>: Send + Sync {
    /// Create a new builder with default configuration
    fn new() -> Self;

    /// Set configuration (consumes and returns the builder for chaining)
    fn with_config(self, config: T) -> Self;

    /// Build the backend, consuming the builder
    fn build(self) -> BackendResult<Box<dyn Backend>>;

    /// Get the current configuration
    fn config(&self) -> &T;

    /// Get a mutable reference to the configuration
    fn config_mut(&mut self) -> &mut T;
}
1210
/// Backend error handling trait for better error context
///
/// Centralizes how a backend annotates, converts, and reports errors.
/// See [`DefaultBackendErrorHandler`] for the stock implementation.
pub trait BackendErrorHandler: Send + Sync {
    /// Handle a backend error and provide context
    fn handle_error(&self, error: TorshError, context: &str) -> TorshError;

    /// Convert a backend-specific error to TorshError
    fn convert_error(&self, error: Box<dyn std::error::Error + Send + Sync>) -> TorshError;

    /// Get error recovery suggestions (human-readable hints)
    fn recovery_suggestions(&self, error: &TorshError) -> Vec<String>;

    /// Log error with appropriate level
    fn log_error(&self, error: &TorshError, context: &str);
}
1225
/// Default error handler implementation
///
/// Prefixes error and log messages with the owning backend's name.
pub struct DefaultBackendErrorHandler {
    // Name used to tag every error/log message produced by this handler.
    backend_name: String,
}
1230
1231impl DefaultBackendErrorHandler {
1232    pub fn new(backend_name: String) -> Self {
1233        Self { backend_name }
1234    }
1235}
1236
1237impl BackendErrorHandler for DefaultBackendErrorHandler {
1238    fn handle_error(&self, error: TorshError, context: &str) -> TorshError {
1239        // Add backend context to error
1240        match error {
1241            TorshError::BackendError(msg) => TorshError::BackendError(format!(
1242                "{}: {} (context: {})",
1243                self.backend_name, msg, context
1244            )),
1245            other => other,
1246        }
1247    }
1248
1249    fn convert_error(&self, error: Box<dyn std::error::Error + Send + Sync>) -> TorshError {
1250        TorshError::BackendError(format!("{}: {}", self.backend_name, error))
1251    }
1252
1253    fn recovery_suggestions(&self, error: &TorshError) -> Vec<String> {
1254        match error {
1255            TorshError::BackendError(msg) if msg.contains("not available") => {
1256                vec![
1257                    "Check if the backend is properly installed".to_string(),
1258                    "Verify system compatibility".to_string(),
1259                    "Try a different backend".to_string(),
1260                ]
1261            }
1262            TorshError::BackendError(msg) if msg.contains("memory") => {
1263                vec![
1264                    "Reduce batch size or tensor dimensions".to_string(),
1265                    "Enable memory optimization".to_string(),
1266                    "Check available memory".to_string(),
1267                ]
1268            }
1269            _ => vec!["Contact support with error details".to_string()],
1270        }
1271    }
1272
1273    fn log_error(&self, error: &TorshError, context: &str) {
1274        eprintln!("[{}] Error in {}: {}", self.backend_name, context, error);
1275    }
1276}
1277
/// Extended Backend trait with device factory methods
impl dyn Backend {
    /// Create the best available backend automatically
    ///
    /// Delegates device selection to `DeviceEnumerator::find_best_device`,
    /// then constructs the backend matching the chosen device type. Each
    /// arm is compile-time gated on the corresponding cargo feature.
    ///
    /// # Errors
    /// Returns `TorshError::BackendError` when no suitable backend is
    /// compiled in or available, and propagates device-enumeration or
    /// backend-construction failures.
    pub fn auto() -> BackendResult<Box<dyn Backend>> {
        // Best device wins; the device handle itself is unused here — only
        // its type drives backend selection.
        let (device_type, _device) = DeviceEnumerator::find_best_device()?;

        match device_type {
            #[cfg(feature = "cpu")]
            DeviceType::Cpu => Ok(Box::new(crate::cpu::CpuBackend::new()?)),
            // `cuda_available` is a custom cfg flag (presumably set by the
            // build script when a CUDA toolchain is detected — TODO confirm).
            #[cfg(cuda_available)]
            DeviceType::Cuda(device_id) => Ok(Box::new(crate::cuda::CudaBackend::new(
                crate::cuda::CudaBackendConfig {
                    device_id: device_id as usize,
                    ..Default::default()
                },
            )?)),
            #[cfg(all(feature = "cuda", not(cuda_available)))]
            DeviceType::Cuda(_) => Err(TorshError::BackendError(
                "CUDA backend not available on this platform".to_string(),
            )),
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            DeviceType::Metal(_) => Ok(Box::new(crate::metal::MetalBackend::new()?)),
            #[cfg(feature = "webgpu")]
            DeviceType::Wgpu(_) => {
                Ok(Box::new(crate::webgpu::WebGpuBackend::with_default_config()))
            }
            // Reachable only when the winning device type's feature is off.
            #[allow(unreachable_patterns)]
            _ => Err(TorshError::BackendError(
                "No suitable backend found".to_string(),
            )),
        }
    }
}