// torsh_backend/lib.rs

//! Unified backend implementation for ToRSh
//!
//! This crate provides a unified backend system that integrates with SciRS2's
//! compute backends. All backend implementations are included in this single
//! crate and selected via feature flags.
//!
//! # Features
//!
//! - `cpu` (default): CPU backend with SIMD optimizations via scirs2-core
//! - `cuda`: NVIDIA GPU backend via scirs2-core's CUDA support
//! - `metal`: Apple GPU backend via scirs2-core's Metal/MPS support
//! - `rocm`: AMD GPU backend (when available in scirs2-core)
//! - `webgpu`: WebGPU backend (when available in scirs2-core)
15#![cfg_attr(not(feature = "std"), no_std)]
16#![allow(clippy::too_many_arguments)]
17#![allow(clippy::uninlined_format_args)]
18#![allow(clippy::new_without_default)]
19#![allow(clippy::if_same_then_else)]
20#![allow(clippy::needless_range_loop)]
21#![allow(clippy::implicit_saturating_sub)]
22#![allow(clippy::unwrap_or_default)]
23#![allow(clippy::manual_div_ceil)]
24#![allow(clippy::wrong_self_convention)]
25#![allow(clippy::type_complexity)]
26#![allow(clippy::not_unsafe_ptr_arg_deref)]
27#![allow(clippy::inherent_to_string)]
28#![allow(clippy::derivable_impls)]
29#![allow(clippy::needless_borrows_for_generic_args)]
30#![allow(clippy::field_reassign_with_default)]
31#![allow(clippy::mut_from_ref)]
32#![allow(clippy::missing_transmute_annotations)]
33#![allow(clippy::should_implement_trait)]
34#![allow(clippy::redundant_closure)]
35#![allow(clippy::manual_flatten)]
36#![allow(clippy::useless_conversion)]
37#![allow(clippy::identity_op)]
38#![allow(clippy::len_without_is_empty)]
39#![allow(dead_code)]
40
41#[cfg(not(feature = "std"))]
42extern crate alloc;
43
/// Backend-specific error types
///
/// Each variant carries a human-readable message describing the failure.
/// The `Display` impl uses `core::fmt` (of which `std::fmt` is a re-export)
/// so it also compiles under the crate's `no_std` configuration.
#[derive(Debug, Clone)]
pub enum BackendError {
    /// Invalid argument provided to backend operation
    InvalidArgument(String),

    /// Operation not supported by this backend
    UnsupportedOperation(String),

    /// Quantization-specific error
    QuantizationError(String),

    /// Invalid buffer state or operation
    InvalidBuffer { message: String },

    /// Runtime error during backend operation
    Runtime { message: String },

    /// Memory allocation error
    AllocationFailed(String),

    /// Device synchronization error
    SynchronizationFailed(String),
}

impl core::fmt::Display for BackendError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            BackendError::InvalidArgument(msg) => write!(f, "Invalid argument: {}", msg),
            BackendError::UnsupportedOperation(msg) => write!(f, "Unsupported operation: {}", msg),
            BackendError::QuantizationError(msg) => write!(f, "Quantization error: {}", msg),
            BackendError::InvalidBuffer { message } => write!(f, "Invalid buffer: {}", message),
            BackendError::Runtime { message } => write!(f, "Runtime error: {}", message),
            BackendError::AllocationFailed(msg) => write!(f, "Allocation failed: {}", msg),
            BackendError::SynchronizationFailed(msg) => {
                write!(f, "Synchronization failed: {}", msg)
            }
        }
    }
}

// `std::error::Error` only exists with `std`; the previous unconditional impl
// broke `no_std` builds. With `std` enabled (the default), behavior is
// unchanged.
#[cfg(feature = "std")]
impl std::error::Error for BackendError {}
86
87// Core backend traits and types
88pub mod adaptive_kernel_selection;
89pub mod backend;
90pub mod buffer;
91pub mod convolution;
92pub mod cross_backend_transfer;
93pub mod cross_backend_validation;
94pub mod deadlock_prevention;
95pub mod device;
96pub mod error;
97pub mod fft;
98pub mod hardware_optimization_tests;
99pub mod introspection;
100pub mod jit_compiler;
101pub mod kernel;
102pub mod kernel_generation;
103pub mod memory;
104pub mod memory_defrag;
105pub mod memory_profiler;
106pub mod performance_modeling;
107pub mod performance_tuning;
108pub mod profiler;
109pub mod property_tests;
110pub mod quantization;
111pub mod rnn;
112pub mod sparse_ops;
113pub mod unified_memory_pool;
114pub mod version_compat;
115pub mod zero_copy;
116
117// Feature-gated backend implementations
118#[cfg(feature = "cpu")]
119pub mod cpu;
120
121#[cfg(feature = "cuda")]
122pub mod cuda;
123
124#[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
125pub mod metal;
126
127#[cfg(feature = "rocm")]
128pub mod rocm;
129
130#[cfg(feature = "webgpu")]
131pub mod webgpu;
132
133// Re-exports
134pub use adaptive_kernel_selection::{
135    AdaptiveKernelSelector, AdaptiveSelectionConfig, BenchmarkResult, BenchmarkResults,
136    CustomKernel, HybridConfig, KernelCharacteristics, KernelConstraints, KernelExecutor,
137    KernelImplementation, KernelInputs, KernelOutputs, KernelParameter, KernelPerformanceRecord,
138    KernelRegistry, KernelSelection, KernelUsageStats, KernelVariant, MLBasedConfig, MLModelType,
139    MLTrainingParams, PerformanceTracker, ResourceRequirements, ScalabilityCharacteristics,
140    ScalingBehavior, ScoreBasedConfig, SelectionAccuracyTracker, SelectionAlgorithm,
141    SelectionReason, SelectionStatistics,
142};
143pub use backend::{
144    Backend, BackendCapabilities, BackendCore, BackendDeviceManager, BackendExecutor,
145    BackendExtension, BackendExtensionRegistry, BackendFactory, BackendLifecycle,
146    BackendOperations, BackendOps, BackendPlugin, BackendRegistry, BackendResourceManager,
147    BackendType, CapabilityValue, DeviceEnumerator, ExecutionModel, ExtendedCapabilities,
148    HardwareFeature, MemoryHierarchy, OperationsBundle, PerformanceHints, PluginMetadata,
149    PrecisionMode, ResourceLimits, ResourceStatistics, ResourceUsage, ScopedResource,
150};
151pub use buffer::{Buffer, BufferDescriptor, BufferHandle, BufferUsage, BufferView, MemoryLocation};
152
/// Buffer error type (alias to BackendError)
///
/// Kept as a plain alias so buffer code can name its own error domain
/// without introducing a separate type; every [`BackendError`] variant
/// applies unchanged.
pub type BufferError = BackendError;
155pub use convolution::{
156    algorithms as conv_algorithms, ConvolutionAlgorithm, ConvolutionConfig, ConvolutionOps,
157    ConvolutionPerformanceHints, ConvolutionType, DefaultConvolutionOps, PaddingMode,
158};
159pub use cross_backend_transfer::CrossBackendTransferManager;
160pub use cross_backend_validation::{
161    compare_f32_values, compare_f64_values, run_cross_backend_validation, CrossBackendValidator,
162};
163pub use device::{
164    Device, DeviceConfiguration, DeviceDiscovery, DeviceFeature, DeviceInfo, DeviceManager,
165    DevicePerformanceInfo, DeviceRequirements, DeviceType, DeviceUtils,
166};
167pub use error::{BackendResult, ErrorCategory, ErrorContext, ErrorSeverity};
168pub use fft::{
169    convenience as fft_convenience, DefaultFftExecutor, DefaultFftOps, FftDirection, FftExecutor,
170    FftNormalization, FftOps, FftPlan, FftType,
171};
172pub use hardware_optimization_tests::{
173    run_hardware_optimization_tests, run_lightweight_hardware_tests, HardwareOptimizationTester,
174};
175pub use kernel::{Kernel, KernelDescriptor, KernelHandle, KernelLaunchConfig, KernelMetadata};
176pub use memory::{
177    AccessPattern, AllocationHint, AllocationLifetime, AllocationStrategy, CompactionResult,
178    DefragmentationPolicy, DefragmentationPriority, DefragmentationResult, DefragmentationStrategy,
179    FragmentationInfo, FragmentationSeverity, FreeListPool, LeakReport, LeakSeverity, LeakType,
180    MemoryAdvice, MemoryManager, MemoryManagerFactory, MemoryPool, MemoryPoolConfig, MemoryStats,
181    PoolStats,
182};
183pub use memory_defrag::{
184    CompactionPlan, DefragmentationManager, DefragmentationRequest, DefragmentationStats,
185    DefragmentationTask, MemoryBlock, MemoryLayout, TaskStatus,
186};
187pub use memory_profiler::{
188    AccessType, AllocationContext, AllocationUsageStats, HintSeverity, MemoryAllocation,
189    MemoryPressureEvent, MemoryProfiler, MemoryProfilerConfig, MemorySnapshot, MemoryType,
190    PerformanceHint, PerformanceHintType, PressureLevel,
191};
192pub use performance_modeling::{
193    AnomalyDetector, AnomalySeverity, AnomalyType, ComplexityClass, CorrelationAnalyzer,
194    CorrelationResult, EnvironmentalFactors, ModelAccuracy, ModelComplexity, ModelTrainingResult,
195    PatternType, PerformanceAnomaly, PerformanceCharacteristics, PerformanceMeasurement,
196    PerformanceModel, PerformanceReport, PerformanceSample, PerformanceTrend, RealtimeStatistics,
197    RuntimeMonitor, RuntimePerformanceModeler, TrendDirection, WorkloadPattern,
198};
199pub use performance_tuning::{
200    analyze_workload_optimization_opportunities,
201    create_default_constraints,
202    create_default_system_state,
203    create_energy_budget_constraints,
204    create_image_processing_workload,
205    create_ml_inference_workload,
206    create_ml_training_workload,
207    create_performance_optimized_system_state,
208    create_power_efficient_system_state,
209    create_realtime_constraints,
210    create_sample_workload,
211    create_throughput_constraints,
212    // Convenience functions
213    new_coordinator,
214    recommend_backend,
215    AccessPattern as PerfAccessPattern,
216    ActualPerformance,
217    BackendTuningStrategy,
218
219    DataType,
220    GlobalPerformanceStats,
221
222    MemoryAllocationStrategy,
223    NumaTopologyState,
224    // Configuration enums
225    OperationType,
226    OptimizationLevel,
227    PerformanceFeedback,
228    // Performance measurement and feedback
229    PerformancePrediction,
230    // Core coordination types
231    PerformanceTuningCoordinator,
232    PowerEfficiencyMode,
233    PowerState,
234    SchedulingStrategy,
235    StrategyMetrics,
236    SystemState,
237    ThermalState,
238    TuningConstraints,
239    TuningParameters,
240
241    TuningRecommendation,
242    TuningValue,
243
244    // Workload and system characteristics
245    WorkloadCharacteristics,
246};
247pub use profiler::{Profiler, ProfilerEvent, ProfilerStats, SimpleProfiler};
248pub use quantization::{
249    CalibrationMethod, QuantizationCalibrator, QuantizationHardwareFeatures, QuantizationOps,
250    QuantizationParams, QuantizationScheme, QuantizedDType, QuantizedTensor, SimdQuantizationOps,
251};
252pub use rnn::{
253    activations as rnn_activations, cells as rnn_cells, DefaultRnnOps, RnnActivation, RnnCellType,
254    RnnConfig, RnnDirection, RnnOps, RnnOutput, RnnPerformanceHints,
255};
256pub use sparse_ops::{
257    DefaultSparseOps, SparseFormat, SparseFormatConverter, SparseMatrix, SparseOperation,
258    SparseOps, SparseOptimizationHints,
259};
260pub use unified_memory_pool::{
261    CpuMemoryPool, CudaMemoryPool, MetalMemoryPool, RocmMemoryPool, UnifiedMemoryPool,
262    WebGpuMemoryPool,
263};
264pub use version_compat::{
265    BackendDependency, CompatibilityReport, DependencyStatus, Version, VersionCompatibilityChecker,
266    VersionError, VersionErrorContextExt, VersionRange,
267};
268pub use zero_copy::{
269    TransferDirection, TransferMode, ZeroCopyCapabilities, ZeroCopyManager, ZeroCopyStats,
270    ZeroCopyTransfer,
271};
272
// Version information
/// Full crate version string, captured from Cargo.toml at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Major version component.
/// NOTE(review): the three numeric components below are hard-coded and must
/// be kept in sync with Cargo.toml manually — confirm they match `VERSION`.
pub const VERSION_MAJOR: u32 = 0;
/// Minor version component (see sync note on `VERSION_MAJOR`).
pub const VERSION_MINOR: u32 = 1;
/// Patch version component (see sync note on `VERSION_MAJOR`).
pub const VERSION_PATCH: u32 = 0;
278
/// Check if the CUDA backend is available
///
/// This is a convenience function primarily used in tests to gate
/// CUDA-specific test execution. Delegates to `cuda::is_available()`.
///
/// NOTE(review): despite the generic name, this consults CUDA only —
/// Metal/ROCm/WebGPU availability is not checked here (use
/// `available_backends()` for that).
#[cfg(feature = "cuda")]
pub fn is_available() -> bool {
    cuda::is_available()
}

/// Stub used when the `cuda` feature is disabled: always returns `false`,
/// even if other GPU backends happen to be compiled in.
#[cfg(not(feature = "cuda"))]
pub fn is_available() -> bool {
    false
}
293
294// SciRS2 integration re-exports
295#[cfg(feature = "cpu")]
296pub use cpu::{prepare_tensor_data, prepare_tensor_data_mut, SciRS2CpuBackend};
297use torsh_core::error::TorshError;
298
299// Removed unused imports: DType, Device as CoreDevice, Shape
300
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
303
/// Unified backend builder
///
/// Collects configuration through chained setters and materializes a
/// concrete backend in [`BackendBuilder::build`]. The default backend type
/// is `Auto`, which probes GPU backends first and falls back to CPU.
pub struct BackendBuilder {
    // Which backend to construct; `Auto` probes in preference order.
    backend_type: BackendType,
    // Device index forwarded to multi-device backends (CUDA/WebGPU).
    device_id: usize,
    // Optional memory-pool configuration forwarded to the chosen backend.
    memory_pool_config: Option<MemoryPoolConfig>,
    // CPU backend only: worker thread count; `None` = backend default.
    num_threads: Option<usize>,
    // NOTE(review): set by `enable_profiling()` but never consumed in
    // `build()` — confirm whether backends should receive this flag.
    enable_profiling: bool,
}
312
impl Default for BackendBuilder {
    /// Equivalent to [`BackendBuilder::new`]: auto backend selection,
    /// device 0, no pool config, default threads, profiling off.
    fn default() -> Self {
        Self::new()
    }
}
318
impl BackendBuilder {
    /// Create a new backend builder
    ///
    /// Defaults: automatic backend selection, device 0, no memory-pool
    /// configuration, backend-default thread count, profiling disabled.
    pub fn new() -> Self {
        Self {
            backend_type: BackendType::Auto,
            device_id: 0,
            memory_pool_config: None,
            num_threads: None,
            enable_profiling: false,
        }
    }

    /// Set the backend type
    pub fn backend_type(mut self, backend_type: BackendType) -> Self {
        self.backend_type = backend_type;
        self
    }

    /// Set the device ID
    ///
    /// Honored by the CUDA and WebGPU builders; the CPU and Metal paths
    /// below currently ignore it.
    pub fn device_id(mut self, device_id: usize) -> Self {
        self.device_id = device_id;
        self
    }

    /// Set memory pool configuration
    pub fn memory_pool(mut self, config: MemoryPoolConfig) -> Self {
        self.memory_pool_config = Some(config);
        self
    }

    /// Set number of threads (CPU backend)
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = Some(num_threads);
        self
    }

    /// Enable profiling
    ///
    /// NOTE(review): this flag is stored but not forwarded to any backend
    /// in `build()` — confirm whether the backend builders support it.
    pub fn enable_profiling(mut self, enable: bool) -> Self {
        self.enable_profiling = enable;
        self
    }

    /// Build the backend
    ///
    /// Dispatches to the feature-gated constructor for the selected type.
    /// Requesting a backend whose feature is disabled yields an `Err`
    /// (from the stub constructors below) rather than a compile error.
    pub fn build(self) -> BackendResult<Box<dyn Backend>> {
        match self.backend_type {
            BackendType::Auto => Self::auto_select(self),
            BackendType::Cpu => Self::build_cpu(self),
            BackendType::Cuda => Self::build_cuda(self),
            BackendType::Metal => Self::build_metal(self),
            BackendType::Rocm => Self::build_rocm(self),
            BackendType::WebGpu => Self::build_webgpu(self),
        }
    }

    // Probe compiled-in backends in preference order (CUDA, Metal, ROCm,
    // WebGPU); the first that builds successfully wins, otherwise fall
    // back to CPU. Each attempt clones the builder so the final CPU
    // fallback still owns an unconsumed copy.
    fn auto_select(builder: Self) -> BackendResult<Box<dyn Backend>> {
        // Try backends in order of preference
        #[cfg(feature = "cuda")]
        if let Ok(backend) = Self::build_cuda(builder.clone()) {
            return Ok(backend);
        }

        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
        if let Ok(backend) = Self::build_metal(builder.clone()) {
            return Ok(backend);
        }

        #[cfg(feature = "rocm")]
        if let Ok(backend) = Self::build_rocm(builder.clone()) {
            return Ok(backend);
        }

        #[cfg(feature = "webgpu")]
        if let Ok(backend) = Self::build_webgpu(builder.clone()) {
            return Ok(backend);
        }

        // Fall back to CPU
        Self::build_cpu(builder)
    }

    // CPU: forwards num_threads and pool config; device_id is not used
    // (the CPU backend exposes a single logical device).
    #[cfg(feature = "cpu")]
    fn build_cpu(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut cpu_builder = cpu::CpuBackend::builder();

        if let Some(num_threads) = builder.num_threads {
            cpu_builder = cpu_builder.num_threads(num_threads);
        }

        if let Some(pool_config) = builder.memory_pool_config {
            cpu_builder = cpu_builder.memory_pool(pool_config);
        }

        Ok(Box::new(cpu_builder.build()?))
    }

    // Stub when the `cpu` feature is off: always an error.
    #[cfg(not(feature = "cpu"))]
    fn build_cpu(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("CPU backend not enabled".into()))
    }

    // CUDA: forwards device_id and pool config.
    #[cfg(feature = "cuda")]
    fn build_cuda(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut cuda_builder = cuda::CudaBackend::builder();

        cuda_builder = cuda_builder.device(builder.device_id);

        if let Some(pool_config) = builder.memory_pool_config {
            cuda_builder = cuda_builder.memory_pool(pool_config);
        }

        Ok(Box::new(cuda_builder.build()?))
    }

    // Stub when the `cuda` feature is off: always an error.
    #[cfg(not(feature = "cuda"))]
    fn build_cuda(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("CUDA backend not enabled".into()))
    }

    // Metal: forwards pool config only; device_id is ignored here —
    // NOTE(review): confirm whether MetalBackend supports device selection.
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    fn build_metal(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut metal_builder = metal::MetalBackend::builder();

        if let Some(pool_config) = builder.memory_pool_config {
            metal_builder = metal_builder.memory_pool(pool_config);
        }

        Ok(Box::new(metal_builder.build()?))
    }

    // Stub when Metal is unavailable (wrong feature, OS, or arch).
    #[cfg(not(all(feature = "metal", target_os = "macos", target_arch = "aarch64")))]
    fn build_metal(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("Metal backend not enabled".into()))
    }

    // ROCm: feature exists but the implementation is pending upstream.
    #[cfg(feature = "rocm")]
    fn build_rocm(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        // TODO: Implement when scirs2 supports ROCm
        Err(TorshError::BackendError(
            "ROCm backend not yet implemented".into(),
        ))
    }

    // Stub when the `rocm` feature is off: always an error.
    #[cfg(not(feature = "rocm"))]
    fn build_rocm(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("ROCm backend not enabled".into()))
    }

    // WebGPU: forwards device_id; of the pool config only `max_size` is
    // honored (mapped to the backend's max buffer size).
    #[cfg(feature = "webgpu")]
    fn build_webgpu(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut webgpu_builder = webgpu::WebGpuBackendBuilder::new();

        // Set device ID
        webgpu_builder = webgpu_builder.device_id(builder.device_id);

        if let Some(pool_config) = builder.memory_pool_config {
            if let Some(max_size) = pool_config.max_size {
                webgpu_builder = webgpu_builder.max_buffer_size(max_size as u64);
            }
        }

        webgpu_builder = webgpu_builder.enable_pipeline_cache(true);

        Ok(Box::new(webgpu_builder.build()))
    }

    // Stub when the `webgpu` feature is off: always an error.
    #[cfg(not(feature = "webgpu"))]
    fn build_webgpu(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError(
            "WebGPU backend not enabled".into(),
        ))
    }
}
491
impl Clone for BackendBuilder {
    // Manual impl; every field is `Clone`, so this could be
    // `#[derive(Clone)]` on the struct instead (the crate-wide
    // `allow(clippy::derivable_impls)` suggests several such manual impls).
    fn clone(&self) -> Self {
        Self {
            backend_type: self.backend_type,
            device_id: self.device_id,
            memory_pool_config: self.memory_pool_config.clone(),
            num_threads: self.num_threads,
            enable_profiling: self.enable_profiling,
        }
    }
}
503
504/// Create a backend with automatic selection
505pub fn auto() -> BackendResult<Box<dyn Backend>> {
506    BackendBuilder::new().build()
507}
508
509/// Create a CPU backend
510pub fn cpu() -> BackendResult<Box<dyn Backend>> {
511    BackendBuilder::new().backend_type(BackendType::Cpu).build()
512}
513
514/// Create a CUDA backend
515pub fn cuda() -> BackendResult<Box<dyn Backend>> {
516    BackendBuilder::new()
517        .backend_type(BackendType::Cuda)
518        .build()
519}
520
521/// Create a Metal backend
522pub fn metal() -> BackendResult<Box<dyn Backend>> {
523    BackendBuilder::new()
524        .backend_type(BackendType::Metal)
525        .build()
526}
527
528/// List available backend types
529#[allow(clippy::vec_init_then_push)]
530pub fn available_backends() -> Vec<BackendType> {
531    let mut backends = vec![];
532
533    #[cfg(feature = "cpu")]
534    backends.push(BackendType::Cpu);
535
536    #[cfg(feature = "cuda")]
537    if cuda::is_available() {
538        backends.push(BackendType::Cuda);
539    }
540
541    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
542    if metal::is_available() {
543        backends.push(BackendType::Metal);
544    }
545
546    #[cfg(feature = "rocm")]
547    if rocm::is_available() {
548        backends.push(BackendType::Rocm);
549    }
550
551    #[cfg(feature = "webgpu")]
552    if webgpu::is_available() {
553        backends.push(BackendType::WebGpu);
554    }
555
556    backends
557}
558
559/// Comprehensive device enumeration across all available backends
560pub fn enumerate_all_devices() -> BackendResult<Vec<(BackendType, Vec<Device>)>> {
561    let mut all_devices = Vec::new();
562
563    // Enumerate CPU devices
564    #[cfg(feature = "cpu")]
565    {
566        match cpu() {
567            Ok(backend) => {
568                if let Ok(devices) = backend.devices() {
569                    all_devices.push((BackendType::Cpu, devices));
570                }
571            }
572            Err(_) => {
573                // CPU backend failed, continue to other backends
574            }
575        }
576    }
577
578    // Enumerate CUDA devices
579    #[cfg(feature = "cuda")]
580    if cuda::is_available() {
581        // Try to enumerate multiple CUDA devices
582        for device_id in 0..cuda::device_count().unwrap_or(0) {
583            match BackendBuilder::new()
584                .backend_type(BackendType::Cuda)
585                .device_id(device_id as usize)
586                .build()
587            {
588                Ok(backend) => {
589                    if let Ok(devices) = backend.devices() {
590                        all_devices.push((BackendType::Cuda, devices));
591                        break; // For now, just get the first available CUDA backend
592                    }
593                }
594                Err(_) => continue,
595            }
596        }
597    }
598
599    // Enumerate Metal devices
600    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
601    if metal::is_available() {
602        match BackendBuilder::new()
603            .backend_type(BackendType::Metal)
604            .build()
605        {
606            Ok(backend) => {
607                if let Ok(devices) = backend.devices() {
608                    // Only add if we actually have devices
609                    if !devices.is_empty() {
610                        all_devices.push((BackendType::Metal, devices));
611                    }
612                }
613            }
614            Err(_) => {
615                // Metal backend failed, continue
616            }
617        }
618    }
619
620    // Enumerate WebGPU devices
621    #[cfg(feature = "webgpu")]
622    if webgpu::is_available() {
623        match BackendBuilder::new()
624            .backend_type(BackendType::WebGpu)
625            .build()
626        {
627            Ok(backend) => {
628                if let Ok(devices) = backend.devices() {
629                    // Only add if we actually have devices
630                    if !devices.is_empty() {
631                        all_devices.push((BackendType::WebGpu, devices));
632                    }
633                }
634            }
635            Err(_) => {
636                // WebGPU backend failed, continue
637            }
638        }
639    }
640
641    Ok(all_devices)
642}
643
644/// Find the best available device based on selection criteria
645pub fn find_best_device(
646    selector: Option<device::DeviceSelector>,
647) -> BackendResult<(BackendType, Device)> {
648    let all_devices = enumerate_all_devices()?;
649
650    if all_devices.is_empty() {
651        return Err(TorshError::BackendError("No devices available".into()));
652    }
653
654    let selector = selector.unwrap_or_default();
655
656    // First pass: try to find an exact match
657    for (backend_type, devices) in &all_devices {
658        for device in devices {
659            if selector.matches(device) {
660                return Ok((*backend_type, device.clone()));
661            }
662        }
663    }
664
665    // Second pass: fallback to best available device with preference order
666    let preference_order = [
667        BackendType::Cuda,
668        BackendType::Metal,
669        BackendType::WebGpu,
670        BackendType::Cpu,
671    ];
672
673    for preferred_backend in &preference_order {
674        for (backend_type, devices) in &all_devices {
675            if backend_type == preferred_backend && !devices.is_empty() {
676                return Ok((*backend_type, devices[0].clone()));
677            }
678        }
679    }
680
681    // Final fallback: return the first available device
682    let (backend_type, devices) = &all_devices[0];
683    Ok((*backend_type, devices[0].clone()))
684}
685
686/// Get device count for a specific backend type
687pub fn device_count(backend_type: BackendType) -> BackendResult<usize> {
688    match backend_type {
689        BackendType::Cpu => Ok(1), // CPU backend always has 1 logical device
690
691        #[cfg(feature = "cuda")]
692        BackendType::Cuda => {
693            if cuda::is_available() {
694                Ok(cuda::device_count().unwrap_or(0) as usize)
695            } else {
696                Ok(0)
697            }
698        }
699
700        #[cfg(not(feature = "cuda"))]
701        BackendType::Cuda => Ok(0),
702
703        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
704        BackendType::Metal => {
705            if metal::is_available() {
706                Ok(metal::device_count().unwrap_or(0))
707            } else {
708                Ok(0)
709            }
710        }
711
712        #[cfg(not(all(feature = "metal", target_os = "macos", target_arch = "aarch64")))]
713        BackendType::Metal => Ok(0),
714
715        #[cfg(feature = "webgpu")]
716        BackendType::WebGpu => {
717            if webgpu::is_available() {
718                Ok(webgpu::device_count().unwrap_or(0))
719            } else {
720                Ok(0)
721            }
722        }
723
724        #[cfg(not(feature = "webgpu"))]
725        BackendType::WebGpu => Ok(0),
726
727        BackendType::Rocm => Ok(0), // Not implemented yet
728        BackendType::Auto => {
729            // For Auto, return the sum of all available devices
730            let mut total = 0;
731            for backend in available_backends() {
732                if backend != BackendType::Auto {
733                    total += device_count(backend)?;
734                }
735            }
736            Ok(total)
737        }
738    }
739}
740
/// Prelude module for convenient imports
///
/// `use torsh_backend::prelude::*;` brings in the constructor helpers
/// (`auto`, `cpu`, `cuda`, ...), the core backend/device types, and the
/// crate version constants listed below.
pub mod prelude {
    pub use crate::{
        auto,
        available_backends,
        compare_f32_values,
        compare_f64_values,
        cpu,
        cuda,
        device_count,
        enumerate_all_devices,
        find_best_device,
        metal,
        run_cross_backend_validation,
        run_hardware_optimization_tests,
        run_lightweight_hardware_tests,
        AdaptiveKernelSelector,
        Backend,
        BackendBuilder,
        BackendCapabilities,
        BackendOps,
        BackendPlugin,
        BackendRegistry,
        BackendResourceManager,
        BackendResult,
        BackendType,
        BenchmarkResult,
        Buffer,
        CompactionPlan,
        CrossBackendValidator,
        DefragmentationManager,
        DefragmentationStats,
        Device,
        ExecutionModel,
        ExtendedCapabilities,
        HardwareFeature,
        HardwareOptimizationTester,
        KernelImplementation,
        KernelSelection,
        KernelVariant,
        MemoryHierarchy,
        MemoryPool,
        OperationType,
        PerformanceMeasurement,
        PerformancePrediction,
        PerformanceReport,
        PerformanceTrend,
        PerformanceTuningCoordinator,
        PluginMetadata,
        PrecisionMode,
        ResourceLimits,
        ResourceStatistics,
        ResourceUsage,
        RuntimePerformanceModeler,
        SelectionAlgorithm,
        TransferDirection,
        TransferMode,
        TuningParameters,
        TuningRecommendation,
        WorkloadCharacteristics,
        ZeroCopyCapabilities,
        ZeroCopyManager,
        ZeroCopyStats,
        ZeroCopyTransfer,
        // Version information
        VERSION,
        VERSION_MAJOR,
        VERSION_MINOR,
        VERSION_PATCH,
    };
}
812
813#[cfg(test)]
814mod tests {
815    use super::*;
816
817    #[test]
818    fn test_backend_builder() {
819        let builder = BackendBuilder::new()
820            .backend_type(BackendType::Cpu)
821            .device_id(0);
822
823        // Should successfully build CPU backend without specifying num_threads
824        // to avoid Rayon global thread pool conflicts in tests
825        let result = builder.build();
826        if let Err(e) = &result {
827            eprintln!("Backend build failed: {:?}", e);
828        }
829        assert!(result.is_ok());
830    }
831
832    #[test]
833    fn test_available_backends() {
834        let backends = available_backends();
835        // At least CPU should be available
836        assert!(!backends.is_empty());
837        assert!(backends.contains(&BackendType::Cpu));
838    }
839
840    #[test]
841    fn test_device_count() {
842        // CPU should always have at least 1 device
843        assert_eq!(device_count(BackendType::Cpu).unwrap(), 1);
844
845        // Auto should return total of all devices
846        let auto_count = device_count(BackendType::Auto).unwrap();
847        assert!(auto_count >= 1); // At least CPU
848
849        // Other backends depend on availability
850        for backend_type in available_backends() {
851            if backend_type != BackendType::Auto {
852                let count = device_count(backend_type).unwrap();
853                assert!(count < usize::MAX); // Should not fail
854            }
855        }
856    }
857
858    #[test]
859    fn test_enumerate_all_devices() {
860        let devices = enumerate_all_devices().unwrap();
861        assert!(!devices.is_empty()); // At least CPU should be available
862
863        // Check that CPU backend is present
864        let has_cpu = devices
865            .iter()
866            .any(|(backend_type, _)| *backend_type == BackendType::Cpu);
867        assert!(has_cpu);
868
869        // Verify each backend has at least one device
870        for (backend_type, device_list) in &devices {
871            assert!(
872                !device_list.is_empty(),
873                "Backend {:?} should have at least one device",
874                backend_type
875            );
876        }
877    }
878
879    #[test]
880    fn test_find_best_device() {
881        let (backend_type, device) = find_best_device(None).unwrap();
882
883        // Should find some device
884        assert!(matches!(
885            backend_type,
886            BackendType::Cpu | BackendType::Cuda | BackendType::Metal | BackendType::WebGpu
887        ));
888
889        // Device should be valid
890        assert!(!device.name().is_empty());
891    }
892
893    #[test]
894    fn test_find_best_device_with_selector() {
895        use crate::device::{DeviceSelector, DeviceType};
896
897        // Try to find a CPU device specifically
898        let selector = DeviceSelector::new().with_device_type(DeviceType::Cpu);
899        let result = find_best_device(Some(selector));
900
901        assert!(result.is_ok());
902        let (backend_type, device) = result.unwrap();
903        assert_eq!(backend_type, BackendType::Cpu);
904        assert_eq!(device.device_type(), torsh_core::device::DeviceType::Cpu);
905    }
906
    // Exercises the error module end to end: context construction and
    // formatting, plus the CUDA/CPU/memory conversion helpers' messages.
    #[test]
    fn test_unified_error_handling() {
        use crate::error::{conversion, ErrorContext};

        // Test error context creation
        let context = ErrorContext::new("test_operation")
            .with_backend("TestBackend")
            .with_device("test:0")
            .with_details("test details");

        // Formatted output must echo every field set above.
        let formatted = context.format();
        assert!(formatted.contains("test_operation"));
        assert!(formatted.contains("backend: TestBackend"));
        assert!(formatted.contains("device: test:0"));
        assert!(formatted.contains("details: test details"));

        // Test error conversion utilities
        let cuda_error =
            conversion::cuda_error_with_context("Test CUDA error", "test_kernel", Some(0));
        let error_str = cuda_error.to_string();
        assert!(error_str.contains("CUDA"));
        assert!(error_str.contains("test_kernel"));
        assert!(error_str.contains("cuda:0"));

        let cpu_error = conversion::cpu_error_with_context("Test CPU error", "test_operation");
        let error_str = cpu_error.to_string();
        assert!(error_str.contains("CPU"));
        assert!(error_str.contains("test_operation"));

        // Test memory error conversion
        let memory_error =
            conversion::memory_error_with_context("Out of memory", 1024, "CUDA", Some("cuda:0"));
        let error_str = memory_error.to_string();
        assert!(error_str.contains("memory_allocation"));
        assert!(error_str.contains("1024 bytes"));
        assert!(error_str.contains("CUDA"));
        assert!(error_str.contains("cuda:0"));
    }
945
946    #[test]
947    fn test_error_context_extension() {
948        // use crate::error::ErrorContextExt; // Currently unused
949        use torsh_core::error::TorshError;
950
951        // Test adding context to an error
952        let result: Result<(), TorshError> =
953            Err(TorshError::ComputeError("Test error".to_string()));
954        let with_context = crate::error::ErrorContextExt::with_operation(result, "test_operation");
955
956        assert!(with_context.is_err());
957        let error_str = with_context.unwrap_err().to_string();
958        assert!(error_str.contains("test_operation"));
959        assert!(error_str.contains("Test error"));
960    }
961
962    // ========== EDGE CASE AND ERROR CONDITION TESTS ==========
963
964    #[test]
965    fn test_invalid_device_id_error() {
966        // Test requesting a device ID that doesn't exist
967        let builder = BackendBuilder::new()
968            .backend_type(BackendType::Cpu)
969            .device_id(999); // CPU only has device 0
970
971        let backend = builder.build().unwrap();
972        let result = backend.create_device(999);
973        assert!(result.is_err());
974
975        // Verify error message is descriptive
976        let error_str = result.unwrap_err().to_string();
977        assert!(error_str.contains("999"));
978        assert!(error_str.contains("not found"));
979    }
980
981    #[test]
982    fn test_backend_builder_invalid_thread_count() {
983        // Test edge case: zero threads
984        let builder = BackendBuilder::new()
985            .backend_type(BackendType::Cpu)
986            .num_threads(0);
987
988        // Should still succeed but fall back to reasonable defaults
989        let result = builder.build();
990        assert!(result.is_ok());
991    }
992
993    #[test]
994    fn test_backend_builder_extreme_thread_count() {
995        // Test edge case: extremely high thread count
996        let builder = BackendBuilder::new()
997            .backend_type(BackendType::Cpu)
998            .num_threads(10000);
999
1000        // Should handle gracefully (Rayon will cap to reasonable limits)
1001        let result = builder.build();
1002        if let Err(ref e) = result {
1003            eprintln!("Backend build failed with extreme thread count: {:?}", e);
1004        }
1005        assert!(result.is_ok());
1006    }
1007
1008    #[test]
1009    fn test_unavailable_backend_selection() {
1010        // Test requesting backends that aren't compiled in
1011        #[cfg(not(feature = "cuda"))]
1012        {
1013            let builder = BackendBuilder::new().backend_type(BackendType::Cuda);
1014            let result = builder.build();
1015            assert!(result.is_err());
1016
1017            let error_str = result.unwrap_err().to_string();
1018            assert!(error_str.contains("not enabled"));
1019        }
1020
1021        #[cfg(not(feature = "metal"))]
1022        {
1023            let builder = BackendBuilder::new().backend_type(BackendType::Metal);
1024            let result = builder.build();
1025            assert!(result.is_err());
1026
1027            let error_str = result.unwrap_err().to_string();
1028            assert!(error_str.contains("not enabled"));
1029        }
1030    }
1031
1032    #[test]
1033    fn test_device_count_edge_cases() {
1034        // Test device count for unavailable backends
1035        #[cfg(not(feature = "cuda"))]
1036        {
1037            let count = device_count(BackendType::Cuda).unwrap();
1038            assert_eq!(count, 0);
1039        }
1040
1041        #[cfg(not(feature = "metal"))]
1042        {
1043            let count = device_count(BackendType::Metal).unwrap();
1044            assert_eq!(count, 0);
1045        }
1046
1047        // Test ROCm (always unavailable currently)
1048        let count = device_count(BackendType::Rocm).unwrap();
1049        assert_eq!(count, 0);
1050    }
1051
1052    #[test]
1053    fn test_find_best_device_no_match() {
1054        use crate::device::{DeviceSelector, DeviceType};
1055
1056        // Try to find a device that doesn't exist
1057        // This should still return a device (fallback behavior)
1058        let selector = DeviceSelector::new().with_device_type(DeviceType::Cuda);
1059        let result = find_best_device(Some(selector));
1060
1061        // Should still return a device (CPU fallback)
1062        assert!(result.is_ok());
1063    }
1064
1065    #[test]
1066    fn test_memory_pool_config_edge_cases() {
1067        // Test memory pool with extreme values
1068        let config = MemoryPoolConfig::new(0); // Zero initial size
1069        assert_eq!(config.initial_size, 0);
1070
1071        let config = MemoryPoolConfig::new(usize::MAX); // Maximum size
1072        assert_eq!(config.initial_size, usize::MAX);
1073
1074        // Test with invalid growth factor
1075        let config = MemoryPoolConfig::new(1024).with_growth_factor(0.0);
1076        assert_eq!(config.growth_factor, 0.0); // Should accept but may cause issues
1077
1078        let config = MemoryPoolConfig::new(1024).with_growth_factor(-1.0);
1079        assert_eq!(config.growth_factor, -1.0); // Should accept but may cause issues
1080    }
1081
1082    #[test]
1083    fn test_memory_pool_config_alignment_edge_cases() {
1084        // Test alignment edge cases
1085        let config = MemoryPoolConfig::new(1024).with_alignment(0);
1086        assert_eq!(config.alignment, 0); // Invalid alignment
1087
1088        let config = MemoryPoolConfig::new(1024).with_alignment(1);
1089        assert_eq!(config.alignment, 1); // Minimal alignment
1090
1091        let config = MemoryPoolConfig::new(1024).with_alignment(4096);
1092        assert_eq!(config.alignment, 4096); // Page-aligned
1093    }
1094
1095    #[test]
1096    fn test_error_handling_with_long_messages() {
1097        use crate::error::conversion;
1098
1099        // Test error handling with very long error messages
1100        let long_message = "x".repeat(10000);
1101        let error = conversion::cpu_error_with_context(long_message.clone(), "test_operation");
1102
1103        let error_str = error.to_string();
1104        assert!(error_str.contains(&long_message));
1105        assert!(error_str.len() > 10000);
1106    }
1107
1108    #[test]
1109    fn test_error_handling_with_special_characters() {
1110        use crate::error::conversion;
1111
1112        // Test error handling with special characters
1113        let special_message = "Error: 測試 ñoño 🚀 \n\t\r";
1114        let error = conversion::cpu_error_with_context(special_message, "test_unicode_operation");
1115
1116        let error_str = error.to_string();
1117        assert!(error_str.contains("測試"));
1118        assert!(error_str.contains("🚀"));
1119    }
1120
1121    #[test]
1122    fn test_concurrent_backend_creation() {
1123        use std::sync::atomic::{AtomicUsize, Ordering};
1124        use std::sync::Arc;
1125        use std::thread;
1126
1127        // Test creating multiple backends concurrently
1128        let success_count = Arc::new(AtomicUsize::new(0));
1129        let error_count = Arc::new(AtomicUsize::new(0));
1130
1131        let mut handles = vec![];
1132
1133        for _ in 0..10 {
1134            let success_count = Arc::clone(&success_count);
1135            let error_count = Arc::clone(&error_count);
1136
1137            let handle = thread::spawn(move || {
1138                let builder = BackendBuilder::new().backend_type(BackendType::Cpu);
1139                match builder.build() {
1140                    Ok(_) => success_count.fetch_add(1, Ordering::Relaxed),
1141                    Err(_) => error_count.fetch_add(1, Ordering::Relaxed),
1142                };
1143            });
1144
1145            handles.push(handle);
1146        }
1147
1148        for handle in handles {
1149            handle.join().unwrap();
1150        }
1151
1152        // At least some should succeed (thread pool initialization might cause some to fail)
1153        let successes = success_count.load(Ordering::Relaxed);
1154        assert!(
1155            successes > 0,
1156            "No backend creation succeeded in concurrent test"
1157        );
1158    }
1159
1160    #[test]
1161    fn test_backend_memory_pressure_simulation() {
1162        // Test backend behavior under simulated memory pressure
1163        let backend = BackendBuilder::new()
1164            .backend_type(BackendType::Cpu)
1165            .memory_pool(MemoryPoolConfig::new(1024)) // Very small pool
1166            .build()
1167            .unwrap();
1168
1169        // This should succeed
1170        let device = backend.default_device().unwrap();
1171        assert!(!device.name().is_empty());
1172    }
1173
1174    #[test]
1175    fn test_enumerate_devices_consistency() {
1176        // Test that device enumeration is consistent across multiple calls
1177        let devices1 = enumerate_all_devices().unwrap();
1178        let devices2 = enumerate_all_devices().unwrap();
1179
1180        // Should return the same number of backends
1181        assert_eq!(devices1.len(), devices2.len());
1182
1183        // Should return the same backend types
1184        let backend_types1: std::collections::HashSet<_> =
1185            devices1.iter().map(|(bt, _)| *bt).collect();
1186        let backend_types2: std::collections::HashSet<_> =
1187            devices2.iter().map(|(bt, _)| *bt).collect();
1188        assert_eq!(backend_types1, backend_types2);
1189    }
1190
1191    #[test]
1192    fn test_device_selector_empty_criteria() {
1193        use crate::device::DeviceSelector;
1194
1195        // Test device selector with no criteria (should match any device)
1196        let selector = DeviceSelector::new();
1197        let result = find_best_device(Some(selector));
1198        assert!(result.is_ok());
1199    }
1200
1201    #[test]
1202    fn test_backend_builder_chain_operations() {
1203        // Test method chaining with all possible configurations
1204        let builder = BackendBuilder::new()
1205            .backend_type(BackendType::Cpu)
1206            .device_id(0)
1207            .num_threads(4)
1208            .memory_pool(MemoryPoolConfig::new(1024 * 1024))
1209            .enable_profiling(true);
1210
1211        let result = builder.build();
1212        assert!(result.is_ok());
1213    }
1214
1215    #[test]
1216    fn test_auto_backend_selection_fallback() {
1217        // Test that auto selection properly falls back through preference order
1218        let builder = BackendBuilder::new().backend_type(BackendType::Auto);
1219        let result = builder.build();
1220
1221        // Should always succeed (CPU fallback)
1222        assert!(result.is_ok());
1223
1224        let backend = result.unwrap();
1225
1226        // Should have at least one device
1227        let devices = backend.devices().unwrap();
1228        assert!(!devices.is_empty());
1229    }
1230
1231    // ========== ADDITIONAL EDGE CASE TESTS ==========
1232
1233    #[test]
1234    fn test_memory_pool_zero_max_size() {
1235        // Test memory pool with zero max size
1236        let config = MemoryPoolConfig::new(1024).with_max_size(0);
1237        let builder = BackendBuilder::new()
1238            .backend_type(BackendType::Cpu)
1239            .memory_pool(config);
1240
1241        // Should handle gracefully (may succeed or fail depending on implementation)
1242        let result = builder.build();
1243        // Don't assert success/failure - implementation defined behavior
1244        match result {
1245            Ok(_) => {
1246                // Success is acceptable
1247            }
1248            Err(_) => {
1249                // Failure is also acceptable for zero max size
1250            }
1251        }
1252    }
1253
1254    #[test]
1255    fn test_memory_pool_negative_growth_factor() {
1256        // Test memory pool with negative growth factor
1257        let config = MemoryPoolConfig::new(1024).with_growth_factor(-0.5);
1258        let builder = BackendBuilder::new()
1259            .backend_type(BackendType::Cpu)
1260            .memory_pool(config);
1261
1262        // Should handle gracefully
1263        let result = builder.build();
1264        // Implementation may accept or reject negative growth factors
1265        match result {
1266            Ok(_) => {
1267                // May accept and handle internally
1268            }
1269            Err(_) => {
1270                // May reject as invalid configuration
1271            }
1272        }
1273    }
1274
1275    #[test]
1276    fn test_device_selector_with_conflicting_criteria() {
1277        use crate::device::{DeviceSelector, DeviceType};
1278
1279        // Test device selector with conflicting criteria
1280        let selector = DeviceSelector::new()
1281            .with_device_type(DeviceType::Cpu)
1282            .with_device_type(DeviceType::Cuda); // Conflicting requirements
1283
1284        let result = find_best_device(Some(selector));
1285        // Should still return a device (last criterion wins or fallback)
1286        assert!(result.is_ok());
1287    }
1288
1289    #[test]
1290    fn test_backend_builder_cloning_with_modifications() {
1291        // Test cloning builder and modifying the clone
1292        let original_builder = BackendBuilder::new()
1293            .backend_type(BackendType::Cpu)
1294            .num_threads(2);
1295
1296        let mut cloned_builder = original_builder.clone();
1297        cloned_builder = cloned_builder.num_threads(4);
1298
1299        // Both should be independent
1300        let original_result = original_builder.build();
1301        let cloned_result = cloned_builder.build();
1302
1303        assert!(original_result.is_ok());
1304        assert!(cloned_result.is_ok());
1305    }
1306
1307    #[test]
1308    fn test_error_context_with_empty_strings() {
1309        use crate::error::ErrorContext;
1310
1311        // Test error context with empty strings
1312        let context = ErrorContext::new("")
1313            .with_backend("")
1314            .with_device("")
1315            .with_details("");
1316
1317        let formatted = context.format();
1318        // Should not panic and should handle empty strings gracefully
1319        assert!(!formatted.is_empty());
1320    }
1321
1322    #[test]
1323    fn test_error_context_with_null_characters() {
1324        use crate::error::ErrorContext;
1325
1326        // Test error context with null characters and control characters
1327        let context = ErrorContext::new("op\0eration")
1328            .with_backend("back\0end")
1329            .with_device("dev\0ice")
1330            .with_details("deta\0ils");
1331
1332        let formatted = context.format();
1333        // Should handle null characters without panicking
1334        assert!(!formatted.is_empty());
1335    }
1336
1337    #[test]
1338    fn test_memory_manager_extreme_alignment() {
1339        // Test memory pool with extreme alignment values
1340        let config = MemoryPoolConfig::new(1024).with_alignment(usize::MAX);
1341        let builder = BackendBuilder::new()
1342            .backend_type(BackendType::Cpu)
1343            .memory_pool(config);
1344
1345        // Should handle extreme alignment gracefully
1346        let result = builder.build();
1347        // Implementation may accept or reject extreme alignment
1348        match result {
1349            Ok(_) => {
1350                // May clamp to reasonable values
1351            }
1352            Err(_) => {
1353                // May reject as invalid
1354            }
1355        }
1356    }
1357
1358    #[test]
1359    fn test_backend_resource_cleanup() {
1360        // Test that backends properly clean up resources
1361        let backend = BackendBuilder::new()
1362            .backend_type(BackendType::Cpu)
1363            .build()
1364            .unwrap();
1365
1366        // Use the backend for operations
1367        let _device = backend.default_device().unwrap();
1368        let _devices = backend.devices().unwrap();
1369
1370        // Drop the backend explicitly
1371        drop(backend);
1372
1373        // Should not leak resources (verified by memory leak detection tools)
1374        // This test mainly ensures no panics during cleanup
1375    }
1376
1377    #[test]
1378    fn test_available_backends_consistency() {
1379        // Test that available_backends() returns consistent results
1380        let backends1 = available_backends();
1381        let backends2 = available_backends();
1382
1383        // Should return the same backends
1384        assert_eq!(backends1, backends2);
1385
1386        // Should always include CPU
1387        assert!(backends1.contains(&BackendType::Cpu));
1388
1389        // Should not include Auto in the list
1390        assert!(!backends1.contains(&BackendType::Auto));
1391    }
1392
1393    #[test]
1394    fn test_device_count_consistency() {
1395        // Test that device_count() returns consistent results
1396        for backend_type in available_backends() {
1397            let count1 = device_count(backend_type).unwrap();
1398            let count2 = device_count(backend_type).unwrap();
1399
1400            assert_eq!(
1401                count1, count2,
1402                "Device count should be consistent for {:?}",
1403                backend_type
1404            );
1405        }
1406    }
1407
1408    #[test]
1409    fn test_enumerate_devices_with_no_backends() {
1410        // This test simulates the scenario where no backends are available
1411        // (can't actually disable all backends, but we can test the empty case handling)
1412        let devices = enumerate_all_devices().unwrap();
1413
1414        // Should never be empty since CPU is always available
1415        assert!(!devices.is_empty());
1416
1417        // But test that our logic handles empty cases in find_best_device
1418        // by testing the early return path
1419    }
1420
1421    #[test]
1422    fn test_backend_capability_reporting() {
1423        // Test that backends properly report their capabilities
1424        let backend = BackendBuilder::new()
1425            .backend_type(BackendType::Cpu)
1426            .build()
1427            .unwrap();
1428
1429        let capabilities = backend.capabilities();
1430
1431        // Should have some capabilities
1432        assert!(!capabilities.supported_dtypes.is_empty());
1433
1434        // CPU should support basic data types
1435        assert!(capabilities
1436            .supported_dtypes
1437            .contains(&torsh_core::DType::F32));
1438        assert!(capabilities
1439            .supported_dtypes
1440            .contains(&torsh_core::DType::F64));
1441    }
1442
1443    #[test]
1444    fn test_error_recovery_and_retry_logic() {
1445        // Test error recovery scenarios
1446        let mut retry_count = 0;
1447        let max_retries = 3;
1448
1449        loop {
1450            // Simulate an operation that might fail
1451            let result = BackendBuilder::new()
1452                .backend_type(BackendType::Cpu)
1453                .num_threads(1) // Use minimal threads to avoid conflicts
1454                .build();
1455
1456            match result {
1457                Ok(_) => {
1458                    // Success
1459                    break;
1460                }
1461                Err(e) => {
1462                    retry_count += 1;
1463                    if retry_count >= max_retries {
1464                        // Test that we can handle the error gracefully
1465                        let error_msg = e.to_string();
1466                        assert!(!error_msg.is_empty());
1467                        break;
1468                    }
1469                    // Simulate delay before retry
1470                    std::thread::sleep(std::time::Duration::from_millis(10));
1471                }
1472            }
1473        }
1474    }
1475
1476    #[test]
1477    fn test_backend_performance_hints() {
1478        // Test backend performance hints system
1479        let backend = BackendBuilder::new()
1480            .backend_type(BackendType::Cpu)
1481            .build()
1482            .unwrap();
1483
1484        let hints = backend.performance_hints();
1485
1486        // Should provide some hints
1487        assert!(hints.optimal_batch_size > 0);
1488
1489        // Hints should be reasonable
1490        assert!(hints.optimal_batch_size <= 1024 * 1024); // Should be reasonable
1491    }
1492
1493    #[test]
1494    fn test_cross_backend_type_compatibility() {
1495        // Test that different backend types can coexist
1496        let cpu_result = BackendBuilder::new().backend_type(BackendType::Cpu).build();
1497
1498        assert!(cpu_result.is_ok());
1499
1500        // Test other backends if available
1501        #[cfg(feature = "cuda")]
1502        {
1503            let cuda_result = BackendBuilder::new()
1504                .backend_type(BackendType::Cuda)
1505                .build();
1506
1507            // May succeed or fail depending on hardware
1508            match cuda_result {
1509                Ok(_) => {
1510                    // Both backends should be able to exist
1511                }
1512                Err(_) => {
1513                    // CUDA may not be available
1514                }
1515            }
1516        }
1517    }
1518
1519    #[test]
1520    fn test_backend_state_isolation() {
1521        // Test that different backend instances don't interfere
1522        let backend1 = BackendBuilder::new()
1523            .backend_type(BackendType::Cpu)
1524            .num_threads(2)
1525            .build()
1526            .unwrap();
1527
1528        let backend2 = BackendBuilder::new()
1529            .backend_type(BackendType::Cpu)
1530            .num_threads(4)
1531            .build()
1532            .unwrap();
1533
1534        // Both should work independently
1535        let device1 = backend1.default_device().unwrap();
1536        let device2 = backend2.default_device().unwrap();
1537
1538        assert!(!device1.name().is_empty());
1539        assert!(!device2.name().is_empty());
1540    }
1541
1542    #[test]
1543    fn test_profiling_enablement() {
1544        // Test backend with profiling enabled
1545        let backend = BackendBuilder::new()
1546            .backend_type(BackendType::Cpu)
1547            .enable_profiling(true)
1548            .build()
1549            .unwrap();
1550
1551        // Should successfully create backend with profiling
1552        let device = backend.default_device().unwrap();
1553        assert!(!device.name().is_empty());
1554
1555        // Test with profiling disabled
1556        let backend_no_prof = BackendBuilder::new()
1557            .backend_type(BackendType::Cpu)
1558            .enable_profiling(false)
1559            .build()
1560            .unwrap();
1561
1562        let device_no_prof = backend_no_prof.default_device().unwrap();
1563        assert!(!device_no_prof.name().is_empty());
1564    }
1565
1566    // ========== CROSS-BACKEND VALIDATION TESTS ==========
1567
1568    #[test]
1569    #[ignore = "Requires CUDA hardware - run with --ignored flag"]
1570    fn test_cross_backend_validation_integration() {
1571        use crate::cross_backend_validation::{
1572            run_cross_backend_validation, CrossBackendValidator,
1573        };
1574
1575        // Test validator creation
1576        let validator = CrossBackendValidator::new();
1577        assert!(!validator.available_backends().is_empty());
1578
1579        // Test individual validation components
1580        // Note: These may fail if some backends aren't available (e.g., missing framework classes)
1581        // which is acceptable in test environments
1582        match validator.validate_device_creation() {
1583            Ok(()) => {} // Validation passed
1584            Err(e) => eprintln!("Device creation validation warning: {}", e),
1585        }
1586        match validator.validate_capabilities_consistency() {
1587            Ok(()) => {} // Validation passed
1588            Err(e) => eprintln!("Capabilities consistency validation warning: {}", e),
1589        }
1590
1591        // Test full validation suite
1592        match run_cross_backend_validation() {
1593            Ok(()) => {
1594                // All validations passed
1595            }
1596            Err(e) => {
1597                // Some validation failed - log but don't fail the test
1598                // since some backends may not be available in CI
1599                eprintln!("Cross-backend validation warning: {}", e);
1600            }
1601        }
1602    }
1603
1604    #[test]
1605    fn test_floating_point_comparison_utilities() {
1606        use crate::cross_backend_validation::{compare_f32_values, compare_f64_values};
1607
1608        // Test normal comparisons
1609        assert!(compare_f32_values(1.0, 1.0, 1e-6));
1610        assert!(compare_f32_values(1.0, 1.0000005, 1e-6));
1611        assert!(!compare_f32_values(1.0, 1.1, 1e-6));
1612
1613        assert!(compare_f64_values(1.0, 1.0, 1e-11));
1614        assert!(compare_f64_values(1.0, 1.00000000001, 1.1e-11));
1615        assert!(!compare_f64_values(1.0, 1.1, 1e-11));
1616
1617        // Test special values
1618        assert!(compare_f32_values(f32::NAN, f32::NAN, 1e-6));
1619        assert!(compare_f32_values(f32::INFINITY, f32::INFINITY, 1e-6));
1620        assert!(!compare_f32_values(f32::INFINITY, f32::NEG_INFINITY, 1e-6));
1621
1622        assert!(compare_f64_values(f64::NAN, f64::NAN, 1e-12));
1623        assert!(compare_f64_values(f64::INFINITY, f64::INFINITY, 1e-12));
1624        assert!(!compare_f64_values(f64::INFINITY, f64::NEG_INFINITY, 1e-12));
1625    }
1626
1627    // ========== HARDWARE OPTIMIZATION TESTS ==========
1628
1629    #[test]
1630    fn test_hardware_optimization_integration() {
1631        use crate::hardware_optimization_tests::{
1632            run_lightweight_hardware_tests, HardwareOptimizationTester,
1633        };
1634
1635        // Test tester creation
1636        let tester = HardwareOptimizationTester::new();
1637        assert!(tester.simd_tests_enabled);
1638        assert!(tester.platform_tests_enabled);
1639        assert!(!tester.performance_tests_enabled); // Should be disabled by default
1640
1641        // Run lightweight tests (suitable for CI)
1642        match run_lightweight_hardware_tests() {
1643            Ok(()) => {
1644                // Tests passed
1645            }
1646            Err(e) => {
1647                // Log warning but don't fail test - hardware detection may not be available
1648                eprintln!("Hardware optimization tests warning: {}", e);
1649            }
1650        }
1651    }
1652
1653    #[test]
1654    fn test_hardware_optimization_tester_configuration() {
1655        use crate::hardware_optimization_tests::HardwareOptimizationTester;
1656
1657        // Test that we can configure the tester
1658        let mut tester = HardwareOptimizationTester::new();
1659
1660        // Modify configuration
1661        tester.simd_tests_enabled = false;
1662        tester.platform_tests_enabled = true;
1663        tester.performance_tests_enabled = false;
1664
1665        // Configuration should be applied
1666        assert!(!tester.simd_tests_enabled);
1667        assert!(tester.platform_tests_enabled);
1668        assert!(!tester.performance_tests_enabled);
1669    }
1670}