1use crate::memory::MemoryManager;
4use crate::profiler::Profiler;
5use crate::{Buffer, BufferDescriptor, Device, Kernel, KernelDescriptor};
6use torsh_core::{device::DeviceType, dtype::DType, error::TorshError};
7
8#[cfg(not(feature = "std"))]
9use alloc::{boxed::Box, string::String, vec::Vec};
10
/// Result type shared by all backend APIs; errors are reported as [`TorshError`].
pub type BackendResult<T> = Result<T, TorshError>;
13
/// Core identity and capability queries every backend must answer.
///
/// Object-safe; serves as the common denominator when handling backends
/// dynamically (see [`Backend::as_core`]).
pub trait BackendCore: Send + Sync + std::fmt::Debug {
    /// The device type this backend drives (CPU, CUDA, Metal, ...).
    fn device_type(&self) -> DeviceType;

    /// Human-readable backend name.
    fn name(&self) -> &str;

    /// Whether the backend can run on the current system.
    fn is_available(&self) -> BackendResult<bool>;

    /// Static capability description (limits, supported dtypes, features).
    fn capabilities(&self) -> BackendCapabilities;

    /// Tuning hints for schedulers and kernel launch configuration.
    fn performance_hints(&self) -> PerformanceHints;
}
33
/// Asynchronous startup and teardown of a backend.
#[async_trait::async_trait]
pub trait BackendLifecycle: Send + Sync {
    /// Acquires devices/contexts; must complete before the backend is used.
    async fn initialize(&mut self) -> BackendResult<()>;

    /// Releases backend resources; the backend must be re-initialized before reuse.
    async fn shutdown(&mut self) -> BackendResult<()>;

    /// True once `initialize` has completed successfully.
    fn is_initialized(&self) -> bool;
}
46
/// Device discovery and creation for a backend.
pub trait BackendDeviceManager: Send + Sync {
    /// All devices this backend can drive.
    fn devices(&self) -> BackendResult<Vec<Device>>;

    /// The backend's preferred device.
    fn default_device(&self) -> BackendResult<Device>;

    /// Creates a handle for the device with the given index.
    fn create_device(&self, device_id: usize) -> BackendResult<Device>;

    /// Number of devices available to this backend.
    fn device_count(&self) -> BackendResult<usize>;

    /// Whether `device_id` refers to a usable device.
    fn is_device_available(&self, device_id: usize) -> bool;
}
64
/// Creation of buffers, kernels, and per-device service objects.
pub trait BackendResourceManager: Send + Sync {
    /// Allocates a buffer on `device` as described by `descriptor`.
    fn create_buffer(
        &self,
        device: &Device,
        descriptor: &BufferDescriptor,
    ) -> BackendResult<Buffer>;

    /// Builds (and possibly compiles) a kernel for `device` from `descriptor`.
    fn create_kernel(
        &self,
        device: &Device,
        descriptor: &KernelDescriptor,
    ) -> BackendResult<Kernel>;

    /// Memory manager bound to `device`.
    fn memory_manager(
        &self,
        device: &Device,
    ) -> BackendResult<Box<dyn MemoryManager + Send + Sync>>;

    /// Profiler for this backend.
    fn profiler(&self) -> BackendResult<Box<dyn Profiler + Send + Sync>>;

    /// Allocates a buffer intended for scoped (short-lived) use.
    ///
    /// NOTE(review): same signature as `create_buffer`; the exact scoping
    /// semantics are implementation-defined — confirm with concrete backends.
    fn create_scoped_buffer(
        &self,
        device: &Device,
        descriptor: &BufferDescriptor,
    ) -> BackendResult<Buffer>;
}
97
/// Generic resource construction with an attached cleanup action.
///
/// Not object-safe (generic method), which is presumably why it is not part of
/// the [`Backend`] supertrait bundle.
pub trait BackendAdvancedResourceManager: Send + Sync {
    /// Builds a resource via `factory` and pairs it with `cleanup`, which runs
    /// when the returned [`ManagedResource`] is dropped (unless the resource is
    /// extracted with [`ManagedResource::take`]).
    fn create_resource_with_cleanup<T, F>(
        &self,
        device: &Device,
        factory: F,
        cleanup: impl FnOnce(&T) + Send + 'static,
    ) -> BackendResult<ManagedResource<T>>
    where
        T: Send + Sync + 'static,
        F: FnOnce(&Device) -> BackendResult<T>;
}
111
/// Asynchronous execution and data-movement primitives.
#[async_trait::async_trait]
pub trait BackendExecutor: Send + Sync {
    /// Waits until all pending work on `device` has completed.
    async fn synchronize(&self, device: &Device) -> BackendResult<()>;

    /// Copies `size` bytes from `src` (starting at `src_offset`) into `dst`
    /// (starting at `dst_offset`).
    async fn copy_buffer(
        &self,
        src: &Buffer,
        dst: &Buffer,
        src_offset: usize,
        dst_offset: usize,
        size: usize,
    ) -> BackendResult<()>;

    /// Uploads host bytes into `dst` starting at `dst_offset`.
    async fn copy_to_device(
        &self,
        src: &[u8],
        dst: &Buffer,
        dst_offset: usize,
    ) -> BackendResult<()>;

    /// Downloads bytes from `src` (starting at `src_offset`) into the host
    /// slice `dst`.
    async fn copy_from_device(
        &self,
        src: &Buffer,
        dst: &mut [u8],
        src_offset: usize,
    ) -> BackendResult<()>;

    /// Launches `kernel` with the bound `buffers` and `uniform_data`, using the
    /// given workgroup size and count (x, y, z).
    async fn execute_kernel(
        &self,
        kernel: &Kernel,
        buffers: &[&Buffer],
        uniform_data: &[u8],
        workgroup_size: (u32, u32, u32),
        workgroup_count: (u32, u32, u32),
    ) -> BackendResult<()>;
}
154
/// Factories for the higher-level operation providers a backend exposes.
pub trait BackendOperations: Send + Sync {
    /// FFT operations provider.
    fn fft_ops(&self) -> Box<dyn crate::fft::FftOps>;

    /// Convolution operations provider.
    fn convolution_ops(&self) -> Box<dyn crate::convolution::ConvolutionOps>;

    /// RNN operations provider.
    fn rnn_ops(&self) -> Box<dyn crate::rnn::RnnOps>;

    /// Sparse-tensor operations provider (f32 element type).
    fn sparse_ops(&self) -> Box<dyn crate::sparse_ops::SparseOps<f32>>;

    /// Quantization operations provider.
    fn quantization_ops(&self) -> Box<dyn crate::quantization::QuantizationOps>;

    /// All providers collected into a single [`OperationsBundle`].
    fn operations_bundle(&self) -> OperationsBundle;
}
175
/// Umbrella trait: a complete backend implements every sub-trait, plus
/// accessors that upcast `self` to the narrower trait-object views.
pub trait Backend:
    BackendCore
    + BackendLifecycle
    + BackendDeviceManager
    + BackendResourceManager
    + BackendExecutor
    + BackendOperations
    + BackendOps
{
    /// View as the core identity/capability interface.
    fn as_core(&self) -> &dyn BackendCore;

    /// Mutable view as the lifecycle interface.
    fn as_lifecycle(&mut self) -> &mut dyn BackendLifecycle;

    /// View as the device-management interface.
    fn as_device_manager(&self) -> &dyn BackendDeviceManager;

    /// View as the resource-management interface.
    fn as_resource_manager(&self) -> &dyn BackendResourceManager;

    /// View as the execution interface.
    fn as_executor(&self) -> &dyn BackendExecutor;

    /// View as the operations-factory interface.
    fn as_operations(&self) -> &dyn BackendOperations;
}
204
/// RAII wrapper that runs a cleanup closure, consuming the resource by value,
/// when dropped — unless the resource is extracted first via `take`.
pub struct ScopedResource<'a, T> {
    // `None` after `take` or after `Drop` has run.
    resource: Option<T>,
    // Cleanup receives the resource by value; `'a` lets it borrow from the scope.
    cleanup: Option<Box<dyn FnOnce(T) + Send + 'a>>,
}
210
211impl<'a, T> ScopedResource<'a, T> {
212 pub fn new(resource: T) -> Self {
214 Self {
215 resource: Some(resource),
216 cleanup: None,
217 }
218 }
219
220 pub fn new_with_cleanup<F>(resource: T, cleanup: F) -> Self
222 where
223 F: FnOnce(T) + Send + 'a,
224 {
225 Self {
226 resource: Some(resource),
227 cleanup: Some(Box::new(cleanup)),
228 }
229 }
230
231 pub fn get(&self) -> Option<&T> {
233 self.resource.as_ref()
234 }
235
236 pub fn get_mut(&mut self) -> Option<&mut T> {
238 self.resource.as_mut()
239 }
240
241 pub fn take(mut self) -> Option<T> {
243 self.cleanup = None; self.resource.take()
245 }
246
247 pub fn with_resource<F, R>(&self, f: F) -> Option<R>
249 where
250 F: FnOnce(&T) -> R,
251 {
252 self.resource.as_ref().map(f)
253 }
254
255 pub fn is_available(&self) -> bool {
257 self.resource.is_some()
258 }
259}
260
261impl<'a, T> Drop for ScopedResource<'a, T> {
262 fn drop(&mut self) {
263 if let (Some(resource), Some(cleanup)) = (self.resource.take(), self.cleanup.take()) {
264 cleanup(resource);
265 }
266 }
267}
268
/// Owning wrapper whose cleanup closure observes the resource by reference on
/// drop; the resource itself is then dropped normally afterwards.
pub struct ManagedResource<T> {
    // `None` after `take`.
    resource: Option<T>,
    // Borrow-based cleanup, so `T`'s own `Drop` still runs after it.
    cleanup: Option<Box<dyn FnOnce(&T) + Send + 'static>>,
}
274
275impl<T> ManagedResource<T> {
276 pub fn new(resource: T) -> Self {
278 Self {
279 resource: Some(resource),
280 cleanup: None,
281 }
282 }
283
284 pub fn new_with_cleanup<F>(resource: T, cleanup: F) -> Self
286 where
287 F: FnOnce(&T) + Send + 'static,
288 {
289 Self {
290 resource: Some(resource),
291 cleanup: Some(Box::new(cleanup)),
292 }
293 }
294
295 pub fn get(&self) -> Option<&T> {
297 self.resource.as_ref()
298 }
299
300 pub fn get_mut(&mut self) -> Option<&mut T> {
302 self.resource.as_mut()
303 }
304
305 pub fn take(mut self) -> Option<T> {
307 self.cleanup = None; self.resource.take()
309 }
310
311 pub fn with_resource<F, R>(&self, f: F) -> Option<R>
313 where
314 F: FnOnce(&T) -> R,
315 {
316 self.resource.as_ref().map(f)
317 }
318
319 pub fn is_available(&self) -> bool {
321 self.resource.is_some()
322 }
323}
324
325impl<T> Drop for ManagedResource<T> {
326 fn drop(&mut self) {
327 if let (Some(resource), Some(cleanup)) = (self.resource.as_ref(), self.cleanup.take()) {
328 cleanup(resource);
329 }
330 }
331}
332
// SAFETY: `ManagedResource` holds `Option<T>` plus a `Box<dyn FnOnce(&T) + Send>`;
// both fields are `Send` when `T: Send`, so this matches what auto-derivation
// would already allow (the impl is effectively redundant but harmless).
unsafe impl<T: Send> Send for ManagedResource<T> {}
// SAFETY: the boxed cleanup closure is only `Send`, not `Sync`, but it is
// private and no `&self` method touches it — it is only invoked via `&mut self`
// in `Drop` — so shared references can never observe or call it concurrently.
// NOTE(review): this invariant must be preserved if new `&self` methods are
// added; verify before extending the API.
unsafe impl<T: Sync> Sync for ManagedResource<T> {}
336
/// All operation providers of a backend collected into one value.
pub struct OperationsBundle {
    /// FFT operations provider.
    pub fft: Box<dyn crate::fft::FftOps>,
    /// Convolution operations provider.
    pub convolution: Box<dyn crate::convolution::ConvolutionOps>,
    /// RNN operations provider.
    pub rnn: Box<dyn crate::rnn::RnnOps>,
    /// Sparse-tensor operations provider (f32 element type).
    pub sparse: Box<dyn crate::sparse_ops::SparseOps<f32>>,
    /// Quantization operations provider.
    pub quantization: Box<dyn crate::quantization::QuantizationOps>,
}
345
346impl OperationsBundle {
347 pub fn new(
349 fft: Box<dyn crate::fft::FftOps>,
350 convolution: Box<dyn crate::convolution::ConvolutionOps>,
351 rnn: Box<dyn crate::rnn::RnnOps>,
352 sparse: Box<dyn crate::sparse_ops::SparseOps<f32>>,
353 quantization: Box<dyn crate::quantization::QuantizationOps>,
354 ) -> Self {
355 Self {
356 fft,
357 convolution,
358 rnn,
359 sparse,
360 quantization,
361 }
362 }
363}
364
/// Static description of a backend's limits and feature support.
#[derive(Debug, Clone)]
pub struct BackendCapabilities {
    /// Maximum size of a single buffer, in bytes.
    pub max_buffer_size: usize,

    /// Number of compute units (cores / SMs — exact meaning is backend-specific).
    pub max_compute_units: usize,

    /// Maximum workgroup dimensions (x, y, z).
    pub max_workgroup_size: (u32, u32, u32),

    /// Element types the backend can compute on.
    pub supported_dtypes: Vec<DType>,

    /// Whether operations may execute asynchronously.
    pub supports_async: bool,

    /// Whether host and device share a unified address space.
    pub supports_unified_memory: bool,

    /// Whether sub-ranges of a buffer can be bound independently.
    pub supports_sub_buffers: bool,

    /// Whether compiled kernels can be cached for reuse.
    pub supports_kernel_caching: bool,

    /// Peak memory bandwidth, GB/s (nominal, per the field name).
    pub memory_bandwidth_gbps: f32,

    /// Peak compute throughput, GFLOPS (nominal, per the field name).
    pub compute_throughput_gflops: f32,

    /// Finer-grained, extensible capability details.
    pub extended_capabilities: ExtendedCapabilities,
}
401
/// Extensible, finer-grained capability description.
#[derive(Debug, Clone)]
pub struct ExtendedCapabilities {
    /// Maximum tensor rank supported, if bounded.
    pub max_tensor_dims: Option<usize>,

    /// Precision modes the backend can compute in.
    pub precision_modes: Vec<PrecisionMode>,

    /// Special hardware features present on the device.
    pub hardware_features: Vec<HardwareFeature>,

    /// Cache/shared-memory layout of the device.
    pub memory_hierarchy: MemoryHierarchy,

    /// Parallel-execution characteristics.
    pub execution_model: ExecutionModel,

    /// Backend-specific capabilities keyed by free-form name.
    pub custom_capabilities: std::collections::HashMap<String, CapabilityValue>,
}
423
/// Floating-point precision modes a backend may support.
#[derive(Debug, Clone, PartialEq)]
pub enum PrecisionMode {
    /// 16-bit floating point.
    F16,
    /// 32-bit floating point.
    F32,
    /// 64-bit floating point.
    F64,
    /// Mixed precision (presumably reduced-precision compute with wider
    /// accumulation — confirm with backend implementations).
    Mixed,
    /// Backend-specific mode identified by an opaque tag.
    Custom(u8),
}
438
/// Notable hardware features a device may expose.
#[derive(Debug, Clone, PartialEq)]
pub enum HardwareFeature {
    /// Dedicated matrix-multiply units (e.g. NVIDIA tensor cores).
    TensorCores,
    /// SIMD/vector execution units.
    VectorUnits,
    /// Programmer-managed shared/local memory.
    SharedMemory,
    /// Read-only constant memory region.
    ConstantMemory,
    /// Hardware atomic operations.
    AtomicOperations,
    /// Cooperative thread-group primitives.
    CooperativeGroups,
    /// Device-side kernel launch.
    DynamicParallelism,
    /// Backend-specific feature identified by name.
    Custom(String),
}
459
/// Cache and shared-memory layout of a device. All fields are optional because
/// not every backend can report them. Sizes are presumably in bytes — confirm
/// with the producers that populate this struct.
#[derive(Debug, Clone, Default)]
pub struct MemoryHierarchy {
    /// L1 cache size, if known.
    pub l1_cache_size: Option<usize>,
    /// L2 cache size, if known.
    pub l2_cache_size: Option<usize>,
    /// L3 cache size, if known.
    pub l3_cache_size: Option<usize>,
    /// Programmer-managed shared memory per unit, if known.
    pub shared_memory_size: Option<usize>,
    /// Typical memory access latency in cycles, if known.
    pub memory_latency_cycles: Option<u32>,
    /// Memory bandwidth available per core, if known.
    pub memory_bandwidth_per_core: Option<f32>,
}
476
/// Parallel-execution characteristics of a backend.
#[derive(Debug, Clone)]
pub struct ExecutionModel {
    /// Vector (SIMD) execution supported.
    pub supports_simd: bool,
    /// GPU-style SIMT execution supported.
    pub supports_simt: bool,
    /// Independent tasks can run concurrently.
    pub supports_task_parallelism: bool,
    /// The same operation can run across data partitions concurrently.
    pub supports_data_parallelism: bool,
    /// Maximum concurrent command streams/queues, if bounded and known.
    pub max_concurrent_streams: Option<u32>,
    /// Commands may complete out of submission order.
    pub supports_out_of_order: bool,
}
493
/// Dynamically-typed value used for free-form capability and argument maps.
#[derive(Debug, Clone)]
pub enum CapabilityValue {
    /// Boolean flag.
    Bool(bool),
    /// Signed integer.
    Int(i64),
    /// Floating-point number.
    Float(f64),
    /// Text value.
    String(String),
    /// Ordered list of nested values.
    List(Vec<CapabilityValue>),
}
503
504impl Default for ExtendedCapabilities {
505 fn default() -> Self {
506 Self {
507 max_tensor_dims: Some(8),
508 precision_modes: vec![PrecisionMode::F32],
509 hardware_features: vec![],
510 memory_hierarchy: MemoryHierarchy::default(),
511 execution_model: ExecutionModel::default(),
512 custom_capabilities: std::collections::HashMap::new(),
513 }
514 }
515}
516
517impl Default for ExecutionModel {
518 fn default() -> Self {
519 Self {
520 supports_simd: false,
521 supports_simt: false,
522 supports_task_parallelism: true,
523 supports_data_parallelism: true,
524 max_concurrent_streams: Some(1),
525 supports_out_of_order: false,
526 }
527 }
528}
529
530impl Default for BackendCapabilities {
531 fn default() -> Self {
532 Self {
533 max_buffer_size: 1024 * 1024 * 1024, max_compute_units: 1,
535 max_workgroup_size: (256, 1, 1),
536 supported_dtypes: vec![DType::F32, DType::F64, DType::I32, DType::I64],
537 supports_async: false,
538 supports_unified_memory: false,
539 supports_sub_buffers: false,
540 supports_kernel_caching: false,
541 memory_bandwidth_gbps: 10.0,
542 compute_throughput_gflops: 100.0,
543 extended_capabilities: ExtendedCapabilities::default(),
544 }
545 }
546}
547
/// Tuning hints a backend publishes for schedulers and kernel launchers.
#[derive(Debug, Clone)]
pub struct PerformanceHints {
    /// Workgroup dimensions (x, y, z) the backend performs best with.
    pub preferred_workgroup_size: (u32, u32, u32),

    /// Preferred memory alignment, in bytes.
    pub memory_alignment: usize,

    /// Whether vectorized kernels should be preferred.
    pub prefer_vectorized: bool,

    /// Whether asynchronous submission should be preferred.
    pub prefer_async: bool,

    /// Batch size the backend performs best with.
    pub optimal_batch_size: usize,

    /// Whether compiled kernels should be cached.
    pub cache_kernels: bool,
}
569
570impl Default for PerformanceHints {
571 fn default() -> Self {
572 Self {
573 preferred_workgroup_size: (64, 1, 1),
574 memory_alignment: 16,
575 prefer_vectorized: true,
576 prefer_async: false,
577 optimal_batch_size: 32,
578 cache_kernels: true,
579 }
580 }
581}
582
/// Identifies a compute-backend implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum BackendType {
    /// Select the best available backend automatically.
    Auto,
    /// Host CPU backend.
    Cpu,
    /// NVIDIA CUDA backend.
    Cuda,
    /// Apple Metal backend.
    Metal,
    /// AMD ROCm backend.
    Rocm,
    /// WebGPU (wgpu) backend.
    WebGpu,
}
600
601impl std::fmt::Display for BackendType {
602 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
603 match self {
604 BackendType::Auto => write!(f, "Auto"),
605 BackendType::Cpu => write!(f, "CPU"),
606 BackendType::Cuda => write!(f, "CUDA"),
607 BackendType::Metal => write!(f, "Metal"),
608 BackendType::Rocm => write!(f, "ROCm"),
609 BackendType::WebGpu => write!(f, "WebGPU"),
610 }
611 }
612}
613
/// Queries about which named operations a backend supports.
pub trait BackendOps: Send + Sync {
    /// The backend-type tag for this implementation.
    fn backend_type(&self) -> BackendType;

    /// Names of all operations this backend implements.
    fn available_ops(&self) -> Vec<&str>;

    /// Whether the operation named `op_name` is implemented.
    fn supports_op(&self, op_name: &str) -> bool;

    /// FFT support flag.
    fn supports_fft(&self) -> bool;

    /// Convolution support flag.
    fn supports_convolution(&self) -> bool;

    /// RNN support flag.
    fn supports_rnn(&self) -> bool;

    /// Sparse-tensor support flag.
    fn supports_sparse(&self) -> bool;

    /// Quantization support flag.
    fn supports_quantization(&self) -> bool;

    /// Detailed capabilities of a single operation, if it is supported.
    fn operation_capabilities(
        &self,
        op_name: &str,
    ) -> Option<std::collections::HashMap<String, CapabilityValue>>;
}
646
/// A pluggable extension that augments a backend with extra operations.
pub trait BackendExtension: Send + Sync {
    /// Unique extension name (used as the registry key).
    fn extension_name(&self) -> &str;

    /// Extension version string.
    fn extension_version(&self) -> &str;

    /// Whether this extension can work with the given backend.
    fn is_compatible_with(&self, backend: &dyn BackendCore) -> bool;

    /// Prepares the extension against a concrete backend.
    fn initialize(&mut self, backend: &dyn Backend) -> BackendResult<()>;

    /// Releases extension resources.
    fn shutdown(&mut self) -> BackendResult<()>;

    /// Free-form capability map describing what the extension provides.
    fn capabilities(&self) -> std::collections::HashMap<String, CapabilityValue>;

    /// Executes an extension-provided operation by name with dynamic arguments.
    fn handle_operation(
        &self,
        op_name: &str,
        args: &[CapabilityValue],
    ) -> BackendResult<CapabilityValue>;
}
674
/// Registry of backend extensions, keyed by extension name, tracking which
/// ones have been successfully initialized.
pub struct BackendExtensionRegistry {
    // Name -> extension instance.
    extensions: std::collections::HashMap<String, Box<dyn BackendExtension>>,
    // Names of extensions whose `initialize` succeeded and have not been shut down.
    initialized_extensions: std::collections::HashSet<String>,
}
681
682impl BackendExtensionRegistry {
683 pub fn new() -> Self {
685 Self {
686 extensions: std::collections::HashMap::new(),
687 initialized_extensions: std::collections::HashSet::new(),
688 }
689 }
690
691 pub fn register_extension(
693 &mut self,
694 extension: Box<dyn BackendExtension>,
695 ) -> BackendResult<()> {
696 let name = extension.extension_name().to_string();
697 if self.extensions.contains_key(&name) {
698 return Err(TorshError::BackendError(format!(
699 "Extension '{}' is already registered",
700 name
701 )));
702 }
703 self.extensions.insert(name, extension);
704 Ok(())
705 }
706
707 pub fn get_extension(&self, name: &str) -> Option<&dyn BackendExtension> {
709 self.extensions.get(name).map(|e| e.as_ref())
710 }
711
712 pub fn get_extension_mut(&mut self, name: &str) -> Option<&mut Box<dyn BackendExtension>> {
714 self.extensions.get_mut(name)
715 }
716
717 pub fn extensions(&self) -> Vec<&str> {
719 self.extensions.keys().map(|s| s.as_str()).collect()
720 }
721
722 pub fn initialize_all(&mut self, backend: &dyn Backend) -> BackendResult<Vec<String>> {
724 let mut failed_extensions = Vec::new();
725
726 for (name, extension) in self.extensions.iter_mut() {
727 if extension.is_compatible_with(backend.as_core()) {
728 match extension.initialize(backend) {
729 Ok(()) => {
730 self.initialized_extensions.insert(name.clone());
731 }
732 Err(e) => {
733 failed_extensions.push(format!("{}: {}", name, e));
734 }
735 }
736 }
737 }
738
739 if failed_extensions.is_empty() {
740 Ok(vec![])
741 } else {
742 Err(TorshError::BackendError(format!(
743 "Failed to initialize extensions: {}",
744 failed_extensions.join(", ")
745 )))
746 }
747 }
748
749 pub fn shutdown_all(&mut self) -> BackendResult<Vec<String>> {
751 let mut failed_extensions = Vec::new();
752
753 for (name, extension) in self.extensions.iter_mut() {
755 if self.initialized_extensions.contains(name) {
756 if let Err(e) = extension.shutdown() {
757 failed_extensions.push(format!("{}: {}", name, e));
758 } else {
759 self.initialized_extensions.remove(name);
760 }
761 }
762 }
763
764 if failed_extensions.is_empty() {
765 Ok(vec![])
766 } else {
767 Err(TorshError::BackendError(format!(
768 "Failed to shutdown extensions: {}",
769 failed_extensions.join(", ")
770 )))
771 }
772 }
773
774 pub fn remove_extension(&mut self, name: &str) -> Option<Box<dyn BackendExtension>> {
776 if let Some(extension) = self.extensions.get_mut(name) {
778 if self.initialized_extensions.contains(name) {
779 let _ = extension.shutdown(); self.initialized_extensions.remove(name);
781 }
782 }
783 self.extensions.remove(name)
784 }
785
786 pub fn has_extension(&self, name: &str) -> bool {
788 self.extensions.contains_key(name)
789 }
790
791 pub fn len(&self) -> usize {
793 self.extensions.len()
794 }
795
796 pub fn is_empty(&self) -> bool {
798 self.extensions.is_empty()
799 }
800}
801
802impl Default for BackendExtensionRegistry {
803 fn default() -> Self {
804 Self::new()
805 }
806}
807
/// Factory that constructs a backend and advertises it for selection.
pub trait BackendFactory: Send + Sync {
    /// Builds a fresh backend instance.
    fn create(&self) -> BackendResult<Box<dyn Backend>>;

    /// Device type the produced backend drives.
    fn device_type(&self) -> DeviceType;

    /// Whether this backend can run on the current system.
    fn is_available(&self) -> bool;

    /// Selection priority (interpretation — higher vs. lower wins — is decided
    /// by the selection code; TODO confirm ordering convention).
    fn priority(&self) -> u32;

    /// Capabilities of the backend this factory would produce.
    fn capabilities(&self) -> BackendCapabilities;
}
825
/// Namespace struct for enumerating devices across all compiled-in backends.
pub struct DeviceEnumerator;
828
impl DeviceEnumerator {
    /// Enumerates devices from every backend enabled at compile time.
    ///
    /// Backends that fail to construct or enumerate are skipped silently; the
    /// result pairs each backend's device type with its device list.
    pub fn enumerate_all_devices() -> BackendResult<Vec<(DeviceType, Vec<Device>)>> {
        let mut all_devices = Vec::new();

        #[cfg(feature = "cpu")]
        {
            if let Ok(cpu_backend) = crate::cpu::CpuBackend::new() {
                if let Ok(devices) = cpu_backend.devices() {
                    all_devices.push((DeviceType::Cpu, devices));
                }
            }
        }

        // NOTE(review): CUDA enumeration is not implemented — CUDA devices are
        // never reported here even when the feature is enabled.
        #[cfg(feature = "cuda")]
        {
        }

        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
        {
            if let Ok(metal_backend) = crate::metal::MetalBackend::new() {
                if let Ok(devices) = metal_backend.devices() {
                    all_devices.push((DeviceType::Metal(0), devices));
                }
            }
        }

        #[cfg(feature = "webgpu")]
        {
            let webgpu_backend = crate::webgpu::WebGpuBackend::with_default_config();
            if let Ok(devices) = webgpu_backend.devices() {
                all_devices.push((DeviceType::Wgpu(0), devices));
            }
        }

        Ok(all_devices)
    }

    /// Picks the best available device: backends are tried in the fixed order
    /// CUDA > Metal > WebGPU > CPU, and within a backend the device with the
    /// highest reported `peak_gflops` wins. Falls back to the first enumerated
    /// device of any type when no preferred backend matched.
    pub fn find_best_device() -> BackendResult<(DeviceType, Device)> {
        let all_devices = Self::enumerate_all_devices()?;

        if all_devices.is_empty() {
            return Err(TorshError::BackendError("No devices available".to_string()));
        }

        // Preference order; index 0 within each type is a wildcard — matching
        // ignores the device index (see `device_types_match`).
        let backend_priorities = [
            DeviceType::Cuda(0),
            DeviceType::Metal(0),
            DeviceType::Wgpu(0),
            DeviceType::Cpu,
        ];

        for preferred_type in &backend_priorities {
            for (device_type, devices) in &all_devices {
                if Self::device_types_match(device_type, preferred_type) && !devices.is_empty() {
                    // Highest peak_gflops; ties/NaNs fall back to Equal ordering.
                    let best_device = devices
                        .iter()
                        .max_by(|a, b| {
                            a.info()
                                .peak_gflops
                                .partial_cmp(&b.info().peak_gflops)
                                .unwrap_or(std::cmp::Ordering::Equal)
                        })
                        .cloned()
                        .expect("devices should not be empty after is_empty check");

                    return Ok((*device_type, best_device));
                }
            }
        }

        // Fallback: first enumerated entry, if it actually holds a device.
        let (device_type, devices) = &all_devices[0];
        if !devices.is_empty() {
            Ok((*device_type, devices[0].clone()))
        } else {
            Err(TorshError::BackendError(
                "No usable devices found".to_string(),
            ))
        }
    }

    /// Same-variant comparison that ignores device indices.
    fn device_types_match(a: &DeviceType, b: &DeviceType) -> bool {
        matches!(
            (a, b),
            (DeviceType::Cpu, DeviceType::Cpu)
                | (DeviceType::Cuda(_), DeviceType::Cuda(_))
                | (DeviceType::Metal(_), DeviceType::Metal(_))
                | (DeviceType::Wgpu(_), DeviceType::Wgpu(_))
        )
    }

    /// Enumerates devices for one backend family; errors when the requested
    /// backend was not compiled in.
    pub fn get_devices_by_type(device_type: DeviceType) -> BackendResult<Vec<Device>> {
        match device_type {
            #[cfg(feature = "cpu")]
            DeviceType::Cpu => {
                let cpu_backend = crate::cpu::CpuBackend::new()?;
                cpu_backend.devices()
            }
            #[cfg(feature = "cuda")]
            DeviceType::Cuda(_device_id) => {
                // NOTE(review): CUDA enumeration not implemented; reports no devices.
                Ok(vec![])
            }
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            DeviceType::Metal(_) => {
                let metal_backend = crate::metal::MetalBackend::new()?;
                metal_backend.devices()
            }
            #[cfg(feature = "webgpu")]
            DeviceType::Wgpu(_) => {
                let webgpu_backend = crate::webgpu::WebGpuBackend::with_default_config();
                webgpu_backend.devices()
            }
            #[allow(unreachable_patterns)]
            _ => Err(TorshError::BackendError(format!(
                "Backend type {device_type:?} not available"
            ))),
        }
    }

    /// Whether the given device type can be used in this build/runtime.
    ///
    /// NOTE(review): under the `webgpu` feature this unconditionally reports
    /// `true` for Wgpu without probing the adapter — confirm that is intended.
    pub fn is_device_type_available(device_type: DeviceType) -> bool {
        match device_type {
            #[cfg(feature = "cpu")]
            DeviceType::Cpu => true,
            // `cuda_available` is a build-script cfg, distinct from the cargo feature.
            #[cfg(cuda_available)]
            DeviceType::Cuda(device_id) => {
                crate::cuda::CudaBackend::new(crate::cuda::CudaBackendConfig {
                    device_id: device_id as usize,
                    ..Default::default()
                })
                .is_ok()
            }
            #[cfg(all(feature = "cuda", not(cuda_available)))]
            DeviceType::Cuda(_) => false, #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            DeviceType::Metal(_) => crate::metal::MetalBackend::new().is_ok(),
            #[cfg(feature = "webgpu")]
            DeviceType::Wgpu(_) => true, #[allow(unreachable_patterns)]
            _ => false,
        }
    }
}
985
/// A dynamically-registered provider of a backend implementation.
pub trait BackendPlugin: Send + Sync + std::fmt::Debug {
    /// Unique plugin name (used as the registry key).
    fn name(&self) -> &str;

    /// Plugin version string.
    fn version(&self) -> &str;

    /// Builds a backend instance from this plugin.
    fn create_backend(&self) -> BackendResult<Box<dyn Backend>>;

    /// Whether the plugin can run on the current system.
    fn is_compatible(&self) -> bool;

    /// Device types the produced backend can drive.
    fn supported_device_types(&self) -> Vec<DeviceType>;

    /// Descriptive metadata about the plugin.
    fn metadata(&self) -> PluginMetadata;
}
1006
/// Descriptive metadata published by a [`BackendPlugin`].
#[derive(Debug, Clone)]
pub struct PluginMetadata {
    /// Plugin name.
    pub name: String,
    /// Plugin version string.
    pub version: String,
    /// Human-readable description.
    pub description: String,
    /// Author attribution.
    pub author: String,
    /// License identifier.
    pub license: String,
    /// CPU/GPU architectures the plugin supports.
    pub supported_architectures: Vec<String>,
    /// Cargo features the plugin requires.
    pub required_features: Vec<String>,
    /// Cargo features the plugin can optionally use.
    pub optional_features: Vec<String>,
}
1019
/// Runtime resource accounting, limiting, and monitoring for a backend.
pub trait BackendResourceMonitor: Send + Sync {
    /// Current resource consumption snapshot.
    fn resource_usage(&self) -> ResourceUsage;

    /// Installs new resource limits.
    fn set_resource_limits(&mut self, limits: ResourceLimits) -> BackendResult<()>;

    /// Currently configured resource limits.
    fn resource_limits(&self) -> ResourceLimits;

    /// Releases unused/cached resources.
    fn cleanup_resources(&mut self) -> BackendResult<()>;

    /// Cumulative resource statistics.
    fn resource_statistics(&self) -> ResourceStatistics;

    /// Turns monitoring on.
    fn enable_monitoring(&mut self) -> BackendResult<()>;

    /// Turns monitoring off.
    fn disable_monitoring(&mut self) -> BackendResult<()>;

    /// Whether monitoring is currently active.
    fn is_monitoring_enabled(&self) -> bool;
}
1046
/// Point-in-time resource consumption snapshot.
#[derive(Debug, Clone)]
pub struct ResourceUsage {
    /// Memory currently in use, presumably in bytes — confirm with implementers.
    pub memory_used: usize,
    /// Number of live buffer allocations.
    pub buffers_allocated: usize,
    /// Number of kernels held in the cache.
    pub kernels_cached: usize,
    /// Number of active command streams.
    pub active_streams: usize,
    /// CPU utilization, 0-100.
    pub cpu_usage_percent: f32,
    /// GPU utilization, 0-100.
    pub gpu_usage_percent: f32,
}
1057
/// Configurable caps on backend resource consumption; `None` means unlimited.
#[derive(Debug, Clone)]
pub struct ResourceLimits {
    /// Maximum memory, presumably in bytes — confirm with implementers.
    pub max_memory: Option<usize>,
    /// Maximum number of live buffers.
    pub max_buffers: Option<usize>,
    /// Maximum number of cached kernels.
    pub max_kernels: Option<usize>,
    /// Maximum number of concurrent streams.
    pub max_streams: Option<usize>,
    /// Fraction of the memory limit at which pressure handling kicks in.
    pub memory_pressure_threshold: f32,
}
1067
/// Cumulative resource statistics gathered over a backend's lifetime.
#[derive(Debug, Clone)]
pub struct ResourceStatistics {
    /// Highest observed memory usage.
    pub peak_memory_usage: usize,
    /// Total allocations performed.
    pub total_allocations: u64,
    /// Total deallocations performed.
    pub total_deallocations: u64,
    /// Mean size of allocated buffers.
    pub average_buffer_size: f32,
    /// Kernel-cache hit rate (0.0-1.0, presumably — confirm with implementers).
    pub cache_hit_rate: f32,
    /// Number of failed allocation attempts.
    pub allocation_failure_count: u32,
}
1078
/// Registry of backend plugins, keyed by plugin name, with an optional default.
pub struct BackendRegistry {
    // Name -> plugin instance.
    backends: std::collections::HashMap<String, Box<dyn BackendPlugin>>,
    // Name of the default plugin; first successful registration wins.
    default_backend: Option<String>,
}
1084
1085impl BackendRegistry {
1086 pub fn new() -> Self {
1088 Self {
1089 backends: std::collections::HashMap::new(),
1090 default_backend: None,
1091 }
1092 }
1093
1094 pub fn register_plugin(&mut self, plugin: Box<dyn BackendPlugin>) -> BackendResult<()> {
1096 let name = plugin.name().to_string();
1097
1098 if !plugin.is_compatible() {
1100 return Err(TorshError::BackendError(format!(
1101 "Plugin {name} is not compatible with current system"
1102 )));
1103 }
1104
1105 self.backends.insert(name.clone(), plugin);
1106
1107 if self.default_backend.is_none() {
1109 self.default_backend = Some(name);
1110 }
1111
1112 Ok(())
1113 }
1114
1115 pub fn available_backends(&self) -> Vec<String> {
1117 self.backends.keys().cloned().collect()
1118 }
1119
1120 pub fn create_backend(&self, name: &str) -> BackendResult<Box<dyn Backend>> {
1122 if let Some(plugin) = self.backends.get(name) {
1123 plugin.create_backend()
1124 } else {
1125 Err(TorshError::BackendError(format!(
1126 "Backend {name} not found"
1127 )))
1128 }
1129 }
1130
1131 pub fn create_default_backend(&self) -> BackendResult<Box<dyn Backend>> {
1133 if let Some(default_name) = &self.default_backend {
1134 self.create_backend(default_name)
1135 } else {
1136 Err(TorshError::BackendError(
1137 "No default backend available".to_string(),
1138 ))
1139 }
1140 }
1141
1142 pub fn set_default_backend(&mut self, name: &str) -> BackendResult<()> {
1144 if self.backends.contains_key(name) {
1145 self.default_backend = Some(name.to_string());
1146 Ok(())
1147 } else {
1148 Err(TorshError::BackendError(format!(
1149 "Backend {name} not found"
1150 )))
1151 }
1152 }
1153
1154 pub fn get_plugin_metadata(&self, name: &str) -> Option<PluginMetadata> {
1156 self.backends.get(name).map(|plugin| plugin.metadata())
1157 }
1158}
1159
1160impl Default for BackendRegistry {
1161 fn default() -> Self {
1162 Self::new()
1163 }
1164}
1165
/// Serializable, mergeable configuration for a backend implementation.
pub trait BackendConfig: Send + Sync + Clone {
    /// Backend this configuration applies to.
    fn backend_type(&self) -> BackendType;

    /// Checks internal consistency of the configuration.
    fn validate(&self) -> BackendResult<()>;

    /// Serializes the configuration into a dynamic property map.
    fn as_properties(&self) -> std::collections::HashMap<String, CapabilityValue>;

    /// Reconstructs a configuration from a dynamic property map.
    fn from_properties(
        properties: &std::collections::HashMap<String, CapabilityValue>,
    ) -> BackendResult<Self>
    where
        Self: Sized;

    /// Merges `other` into `self` (precedence rules are implementation-defined).
    fn merge(&mut self, other: &Self) -> BackendResult<()>;

    /// The implementation's canonical default configuration.
    fn default_config() -> Self
    where
        Self: Sized;
}
1192
/// Builder that constructs a backend from a typed configuration `T`.
pub trait BackendBuilder<T: BackendConfig>: Send + Sync {
    /// Creates a builder with default settings.
    fn new() -> Self;

    /// Replaces the builder's configuration (consuming-builder style).
    fn with_config(self, config: T) -> Self;

    /// Consumes the builder and constructs the backend.
    fn build(self) -> BackendResult<Box<dyn Backend>>;

    /// Shared access to the current configuration.
    fn config(&self) -> &T;

    /// Mutable access to the current configuration.
    fn config_mut(&mut self) -> &mut T;
}
1210
/// Backend-specific error enrichment, conversion, and reporting.
pub trait BackendErrorHandler: Send + Sync {
    /// Annotates an error with backend/context information.
    fn handle_error(&self, error: TorshError, context: &str) -> TorshError;

    /// Converts a foreign error type into a [`TorshError`].
    fn convert_error(&self, error: Box<dyn std::error::Error + Send + Sync>) -> TorshError;

    /// Human-readable remediation suggestions for an error.
    fn recovery_suggestions(&self, error: &TorshError) -> Vec<String>;

    /// Reports an error through the handler's logging channel.
    fn log_error(&self, error: &TorshError, context: &str);
}
1225
/// Stock [`BackendErrorHandler`] that prefixes messages with a backend name
/// and logs to stderr.
pub struct DefaultBackendErrorHandler {
    // Prefix used in every formatted/logged message.
    backend_name: String,
}
1230
1231impl DefaultBackendErrorHandler {
1232 pub fn new(backend_name: String) -> Self {
1233 Self { backend_name }
1234 }
1235}
1236
1237impl BackendErrorHandler for DefaultBackendErrorHandler {
1238 fn handle_error(&self, error: TorshError, context: &str) -> TorshError {
1239 match error {
1241 TorshError::BackendError(msg) => TorshError::BackendError(format!(
1242 "{}: {} (context: {})",
1243 self.backend_name, msg, context
1244 )),
1245 other => other,
1246 }
1247 }
1248
1249 fn convert_error(&self, error: Box<dyn std::error::Error + Send + Sync>) -> TorshError {
1250 TorshError::BackendError(format!("{}: {}", self.backend_name, error))
1251 }
1252
1253 fn recovery_suggestions(&self, error: &TorshError) -> Vec<String> {
1254 match error {
1255 TorshError::BackendError(msg) if msg.contains("not available") => {
1256 vec![
1257 "Check if the backend is properly installed".to_string(),
1258 "Verify system compatibility".to_string(),
1259 "Try a different backend".to_string(),
1260 ]
1261 }
1262 TorshError::BackendError(msg) if msg.contains("memory") => {
1263 vec![
1264 "Reduce batch size or tensor dimensions".to_string(),
1265 "Enable memory optimization".to_string(),
1266 "Check available memory".to_string(),
1267 ]
1268 }
1269 _ => vec!["Contact support with error details".to_string()],
1270 }
1271 }
1272
1273 fn log_error(&self, error: &TorshError, context: &str) {
1274 eprintln!("[{}] Error in {}: {}", self.backend_name, context, error);
1275 }
1276}
1277
impl dyn Backend {
    /// Constructs the best available backend for this system.
    ///
    /// Delegates device selection to [`DeviceEnumerator::find_best_device`] and
    /// then instantiates the matching feature-gated backend. Which arms exist
    /// depends on compile-time features (`cpu`, `cuda`/`cuda_available`,
    /// `metal` on Apple Silicon macOS, `webgpu`).
    pub fn auto() -> BackendResult<Box<dyn Backend>> {
        let (device_type, _device) = DeviceEnumerator::find_best_device()?;

        match device_type {
            #[cfg(feature = "cpu")]
            DeviceType::Cpu => Ok(Box::new(crate::cpu::CpuBackend::new()?)),
            // `cuda_available` is a build-script cfg set when a CUDA toolchain
            // was detected, distinct from the cargo `cuda` feature.
            #[cfg(cuda_available)]
            DeviceType::Cuda(device_id) => Ok(Box::new(crate::cuda::CudaBackend::new(
                crate::cuda::CudaBackendConfig {
                    device_id: device_id as usize,
                    ..Default::default()
                },
            )?)),
            // Feature requested but no CUDA toolchain detected at build time.
            #[cfg(all(feature = "cuda", not(cuda_available)))]
            DeviceType::Cuda(_) => Err(TorshError::BackendError(
                "CUDA backend not available on this platform".to_string(),
            )),
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            DeviceType::Metal(_) => Ok(Box::new(crate::metal::MetalBackend::new()?)),
            #[cfg(feature = "webgpu")]
            DeviceType::Wgpu(_) => {
                Ok(Box::new(crate::webgpu::WebGpuBackend::with_default_config()))
            }
            #[allow(unreachable_patterns)]
            _ => Err(TorshError::BackendError(
                "No suitable backend found".to_string(),
            )),
        }
    }
}