1#![cfg_attr(not(feature = "std"), no_std)]
16#![allow(clippy::too_many_arguments)]
17#![allow(clippy::uninlined_format_args)]
18#![allow(clippy::new_without_default)]
19#![allow(clippy::if_same_then_else)]
20#![allow(clippy::needless_range_loop)]
21#![allow(clippy::implicit_saturating_sub)]
22#![allow(clippy::unwrap_or_default)]
23#![allow(clippy::manual_div_ceil)]
24#![allow(clippy::wrong_self_convention)]
25#![allow(clippy::type_complexity)]
26#![allow(clippy::not_unsafe_ptr_arg_deref)]
27#![allow(clippy::inherent_to_string)]
28#![allow(clippy::derivable_impls)]
29#![allow(clippy::needless_borrows_for_generic_args)]
30#![allow(clippy::field_reassign_with_default)]
31#![allow(clippy::mut_from_ref)]
32#![allow(clippy::missing_transmute_annotations)]
33#![allow(clippy::should_implement_trait)]
34#![allow(clippy::redundant_closure)]
35#![allow(clippy::manual_flatten)]
36#![allow(clippy::useless_conversion)]
37#![allow(clippy::identity_op)]
38#![allow(clippy::len_without_is_empty)]
39#![allow(dead_code)]
40
41#[cfg(not(feature = "std"))]
42extern crate alloc;
43
/// Errors that can occur inside a compute backend.
///
/// Each variant carries a human-readable message describing the failure.
#[derive(Debug, Clone)]
pub enum BackendError {
    /// A caller supplied an invalid argument.
    InvalidArgument(String),

    /// The requested operation is not supported by this backend.
    UnsupportedOperation(String),

    /// Quantization or dequantization failed.
    QuantizationError(String),

    /// A buffer was in an invalid state for the requested use.
    InvalidBuffer { message: String },

    /// A generic runtime failure inside the backend.
    Runtime { message: String },

    /// Memory or resource allocation failed.
    AllocationFailed(String),

    /// Device/stream synchronization failed.
    SynchronizationFailed(String),
}

// Use `core::fmt` (re-exported by `std`) so this impl also compiles under
// `no_std`: the crate is `#![cfg_attr(not(feature = "std"), no_std)]`, but the
// previous impl referenced `std::fmt`, which does not exist in those builds.
impl core::fmt::Display for BackendError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            BackendError::InvalidArgument(msg) => write!(f, "Invalid argument: {}", msg),
            BackendError::UnsupportedOperation(msg) => write!(f, "Unsupported operation: {}", msg),
            BackendError::QuantizationError(msg) => write!(f, "Quantization error: {}", msg),
            BackendError::InvalidBuffer { message } => write!(f, "Invalid buffer: {}", message),
            BackendError::Runtime { message } => write!(f, "Runtime error: {}", message),
            BackendError::AllocationFailed(msg) => write!(f, "Allocation failed: {}", msg),
            BackendError::SynchronizationFailed(msg) => {
                write!(f, "Synchronization failed: {}", msg)
            }
        }
    }
}

// `std::error::Error` only exists with the `std` feature; gate the impl so
// `no_std` builds do not reference `std`. Behavior is unchanged for `std`
// builds.
#[cfg(feature = "std")]
impl std::error::Error for BackendError {}
86
87pub mod adaptive_kernel_selection;
89pub mod backend;
90pub mod buffer;
91pub mod convolution;
92pub mod cross_backend_transfer;
93pub mod cross_backend_validation;
94pub mod deadlock_prevention;
95pub mod device;
96pub mod error;
97pub mod fft;
98pub mod hardware_optimization_tests;
99pub mod introspection;
100pub mod jit_compiler;
101pub mod kernel;
102pub mod kernel_generation;
103pub mod memory;
104pub mod memory_defrag;
105pub mod memory_profiler;
106pub mod performance_modeling;
107pub mod performance_tuning;
108pub mod profiler;
109pub mod property_tests;
110pub mod quantization;
111pub mod rnn;
112pub mod sparse_ops;
113pub mod unified_memory_pool;
114pub mod version_compat;
115pub mod zero_copy;
116
117#[cfg(feature = "cpu")]
119pub mod cpu;
120
121#[cfg(feature = "cuda")]
122pub mod cuda;
123
124#[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
125pub mod metal;
126
127#[cfg(feature = "rocm")]
128pub mod rocm;
129
130#[cfg(feature = "webgpu")]
131pub mod webgpu;
132
133pub use adaptive_kernel_selection::{
135 AdaptiveKernelSelector, AdaptiveSelectionConfig, BenchmarkResult, BenchmarkResults,
136 CustomKernel, HybridConfig, KernelCharacteristics, KernelConstraints, KernelExecutor,
137 KernelImplementation, KernelInputs, KernelOutputs, KernelParameter, KernelPerformanceRecord,
138 KernelRegistry, KernelSelection, KernelUsageStats, KernelVariant, MLBasedConfig, MLModelType,
139 MLTrainingParams, PerformanceTracker, ResourceRequirements, ScalabilityCharacteristics,
140 ScalingBehavior, ScoreBasedConfig, SelectionAccuracyTracker, SelectionAlgorithm,
141 SelectionReason, SelectionStatistics,
142};
143pub use backend::{
144 Backend, BackendCapabilities, BackendCore, BackendDeviceManager, BackendExecutor,
145 BackendExtension, BackendExtensionRegistry, BackendFactory, BackendLifecycle,
146 BackendOperations, BackendOps, BackendPlugin, BackendRegistry, BackendResourceManager,
147 BackendType, CapabilityValue, DeviceEnumerator, ExecutionModel, ExtendedCapabilities,
148 HardwareFeature, MemoryHierarchy, OperationsBundle, PerformanceHints, PluginMetadata,
149 PrecisionMode, ResourceLimits, ResourceStatistics, ResourceUsage, ScopedResource,
150};
151pub use buffer::{Buffer, BufferDescriptor, BufferHandle, BufferUsage, BufferView, MemoryLocation};
152
153pub type BufferError = BackendError;
155pub use convolution::{
156 algorithms as conv_algorithms, ConvolutionAlgorithm, ConvolutionConfig, ConvolutionOps,
157 ConvolutionPerformanceHints, ConvolutionType, DefaultConvolutionOps, PaddingMode,
158};
159pub use cross_backend_transfer::CrossBackendTransferManager;
160pub use cross_backend_validation::{
161 compare_f32_values, compare_f64_values, run_cross_backend_validation, CrossBackendValidator,
162};
163pub use device::{
164 Device, DeviceConfiguration, DeviceDiscovery, DeviceFeature, DeviceInfo, DeviceManager,
165 DevicePerformanceInfo, DeviceRequirements, DeviceType, DeviceUtils,
166};
167pub use error::{BackendResult, ErrorCategory, ErrorContext, ErrorSeverity};
168pub use fft::{
169 convenience as fft_convenience, DefaultFftExecutor, DefaultFftOps, FftDirection, FftExecutor,
170 FftNormalization, FftOps, FftPlan, FftType,
171};
172pub use hardware_optimization_tests::{
173 run_hardware_optimization_tests, run_lightweight_hardware_tests, HardwareOptimizationTester,
174};
175pub use kernel::{Kernel, KernelDescriptor, KernelHandle, KernelLaunchConfig, KernelMetadata};
176pub use memory::{
177 AccessPattern, AllocationHint, AllocationLifetime, AllocationStrategy, CompactionResult,
178 DefragmentationPolicy, DefragmentationPriority, DefragmentationResult, DefragmentationStrategy,
179 FragmentationInfo, FragmentationSeverity, FreeListPool, LeakReport, LeakSeverity, LeakType,
180 MemoryAdvice, MemoryManager, MemoryManagerFactory, MemoryPool, MemoryPoolConfig, MemoryStats,
181 PoolStats,
182};
183pub use memory_defrag::{
184 CompactionPlan, DefragmentationManager, DefragmentationRequest, DefragmentationStats,
185 DefragmentationTask, MemoryBlock, MemoryLayout, TaskStatus,
186};
187pub use memory_profiler::{
188 AccessType, AllocationContext, AllocationUsageStats, HintSeverity, MemoryAllocation,
189 MemoryPressureEvent, MemoryProfiler, MemoryProfilerConfig, MemorySnapshot, MemoryType,
190 PerformanceHint, PerformanceHintType, PressureLevel,
191};
192pub use performance_modeling::{
193 AnomalyDetector, AnomalySeverity, AnomalyType, ComplexityClass, CorrelationAnalyzer,
194 CorrelationResult, EnvironmentalFactors, ModelAccuracy, ModelComplexity, ModelTrainingResult,
195 PatternType, PerformanceAnomaly, PerformanceCharacteristics, PerformanceMeasurement,
196 PerformanceModel, PerformanceReport, PerformanceSample, PerformanceTrend, RealtimeStatistics,
197 RuntimeMonitor, RuntimePerformanceModeler, TrendDirection, WorkloadPattern,
198};
199pub use performance_tuning::{
200 analyze_workload_optimization_opportunities,
201 create_default_constraints,
202 create_default_system_state,
203 create_energy_budget_constraints,
204 create_image_processing_workload,
205 create_ml_inference_workload,
206 create_ml_training_workload,
207 create_performance_optimized_system_state,
208 create_power_efficient_system_state,
209 create_realtime_constraints,
210 create_sample_workload,
211 create_throughput_constraints,
212 new_coordinator,
214 recommend_backend,
215 AccessPattern as PerfAccessPattern,
216 ActualPerformance,
217 BackendTuningStrategy,
218
219 DataType,
220 GlobalPerformanceStats,
221
222 MemoryAllocationStrategy,
223 NumaTopologyState,
224 OperationType,
226 OptimizationLevel,
227 PerformanceFeedback,
228 PerformancePrediction,
230 PerformanceTuningCoordinator,
232 PowerEfficiencyMode,
233 PowerState,
234 SchedulingStrategy,
235 StrategyMetrics,
236 SystemState,
237 ThermalState,
238 TuningConstraints,
239 TuningParameters,
240
241 TuningRecommendation,
242 TuningValue,
243
244 WorkloadCharacteristics,
246};
247pub use profiler::{Profiler, ProfilerEvent, ProfilerStats, SimpleProfiler};
248pub use quantization::{
249 CalibrationMethod, QuantizationCalibrator, QuantizationHardwareFeatures, QuantizationOps,
250 QuantizationParams, QuantizationScheme, QuantizedDType, QuantizedTensor, SimdQuantizationOps,
251};
252pub use rnn::{
253 activations as rnn_activations, cells as rnn_cells, DefaultRnnOps, RnnActivation, RnnCellType,
254 RnnConfig, RnnDirection, RnnOps, RnnOutput, RnnPerformanceHints,
255};
256pub use sparse_ops::{
257 DefaultSparseOps, SparseFormat, SparseFormatConverter, SparseMatrix, SparseOperation,
258 SparseOps, SparseOptimizationHints,
259};
260pub use unified_memory_pool::{
261 CpuMemoryPool, CudaMemoryPool, MetalMemoryPool, RocmMemoryPool, UnifiedMemoryPool,
262 WebGpuMemoryPool,
263};
264pub use version_compat::{
265 BackendDependency, CompatibilityReport, DependencyStatus, Version, VersionCompatibilityChecker,
266 VersionError, VersionErrorContextExt, VersionRange,
267};
268pub use zero_copy::{
269 TransferDirection, TransferMode, ZeroCopyCapabilities, ZeroCopyManager, ZeroCopyStats,
270 ZeroCopyTransfer,
271};
272
273pub const VERSION: &str = env!("CARGO_PKG_VERSION");
275pub const VERSION_MAJOR: u32 = 0;
276pub const VERSION_MINOR: u32 = 1;
277pub const VERSION_PATCH: u32 = 0;
278
/// Reports whether CUDA acceleration is usable at runtime.
#[cfg(feature = "cuda")]
pub fn is_available() -> bool {
    cuda::is_available()
}

/// CUDA support was not compiled in, so acceleration is never available.
#[cfg(not(feature = "cuda"))]
pub fn is_available() -> bool {
    false
}
293
294#[cfg(feature = "cpu")]
296pub use cpu::{prepare_tensor_data, prepare_tensor_data_mut, SciRS2CpuBackend};
297use torsh_core::error::TorshError;
298
299#[cfg(not(feature = "std"))]
302use alloc::{boxed::Box, vec::Vec};
303
/// Fluent builder for constructing a [`Backend`] instance.
pub struct BackendBuilder {
    // Which backend implementation to construct (`Auto` = pick best available).
    backend_type: BackendType,
    // Target device index; forwarded to the CUDA and WebGPU builders.
    device_id: usize,
    // Optional memory-pool configuration forwarded to the backend builder.
    memory_pool_config: Option<MemoryPoolConfig>,
    // Optional worker-thread count; consumed by the CPU backend only.
    num_threads: Option<usize>,
    // Requested profiling flag.
    // NOTE(review): stored but not consulted by any `build_*` path in this
    // file — confirm whether backends are expected to read it.
    enable_profiling: bool,
}
312
impl Default for BackendBuilder {
    /// Same configuration as [`BackendBuilder::new`]: automatic backend
    /// selection on device 0 with no overrides.
    fn default() -> Self {
        Self::new()
    }
}
318
impl BackendBuilder {
    /// Creates a builder with defaults: automatic backend selection,
    /// device 0, no memory-pool override, no thread-count override, and
    /// profiling disabled.
    pub fn new() -> Self {
        Self {
            backend_type: BackendType::Auto,
            device_id: 0,
            memory_pool_config: None,
            num_threads: None,
            enable_profiling: false,
        }
    }

    /// Sets which backend implementation `build` should construct.
    pub fn backend_type(mut self, backend_type: BackendType) -> Self {
        self.backend_type = backend_type;
        self
    }

    /// Sets the target device index (used by the CUDA and WebGPU builders).
    pub fn device_id(mut self, device_id: usize) -> Self {
        self.device_id = device_id;
        self
    }

    /// Supplies a memory-pool configuration forwarded to the backend.
    pub fn memory_pool(mut self, config: MemoryPoolConfig) -> Self {
        self.memory_pool_config = Some(config);
        self
    }

    /// Sets the worker-thread count (consumed by the CPU backend only).
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = Some(num_threads);
        self
    }

    /// Requests profiling.
    ///
    /// NOTE(review): this flag is stored but not read by any `build_*` path
    /// below — confirm whether backends honor it elsewhere.
    pub fn enable_profiling(mut self, enable: bool) -> Self {
        self.enable_profiling = enable;
        self
    }

    /// Constructs the backend selected by `backend_type`.
    ///
    /// # Errors
    /// Returns an error when the requested backend is compiled out, fails to
    /// initialize, or (for `Auto`) when no backend can be built at all.
    pub fn build(self) -> BackendResult<Box<dyn Backend>> {
        match self.backend_type {
            BackendType::Auto => Self::auto_select(self),
            BackendType::Cpu => Self::build_cpu(self),
            BackendType::Cuda => Self::build_cuda(self),
            BackendType::Metal => Self::build_metal(self),
            BackendType::Rocm => Self::build_rocm(self),
            BackendType::WebGpu => Self::build_webgpu(self),
        }
    }

    /// Tries backends in preference order (CUDA, Metal, ROCm, WebGPU) and
    /// falls back to the CPU backend. Each attempt clones the builder so a
    /// failed attempt does not consume the configuration.
    fn auto_select(builder: Self) -> BackendResult<Box<dyn Backend>> {
        #[cfg(feature = "cuda")]
        if let Ok(backend) = Self::build_cuda(builder.clone()) {
            return Ok(backend);
        }

        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
        if let Ok(backend) = Self::build_metal(builder.clone()) {
            return Ok(backend);
        }

        #[cfg(feature = "rocm")]
        if let Ok(backend) = Self::build_rocm(builder.clone()) {
            return Ok(backend);
        }

        #[cfg(feature = "webgpu")]
        if let Ok(backend) = Self::build_webgpu(builder.clone()) {
            return Ok(backend);
        }

        // Last resort: CPU (errors if the `cpu` feature is disabled too).
        Self::build_cpu(builder)
    }

    // Builds the CPU backend, forwarding thread-count and memory-pool
    // settings when present.
    #[cfg(feature = "cpu")]
    fn build_cpu(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut cpu_builder = cpu::CpuBackend::builder();

        if let Some(num_threads) = builder.num_threads {
            cpu_builder = cpu_builder.num_threads(num_threads);
        }

        if let Some(pool_config) = builder.memory_pool_config {
            cpu_builder = cpu_builder.memory_pool(pool_config);
        }

        Ok(Box::new(cpu_builder.build()?))
    }

    // Stub used when the `cpu` feature is compiled out.
    #[cfg(not(feature = "cpu"))]
    fn build_cpu(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("CPU backend not enabled".into()))
    }

    // Builds the CUDA backend on the configured device ordinal.
    #[cfg(feature = "cuda")]
    fn build_cuda(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut cuda_builder = cuda::CudaBackend::builder();

        cuda_builder = cuda_builder.device(builder.device_id);

        if let Some(pool_config) = builder.memory_pool_config {
            cuda_builder = cuda_builder.memory_pool(pool_config);
        }

        Ok(Box::new(cuda_builder.build()?))
    }

    // Stub used when the `cuda` feature is compiled out.
    #[cfg(not(feature = "cuda"))]
    fn build_cuda(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("CUDA backend not enabled".into()))
    }

    // Builds the Metal backend (Apple Silicon macOS only).
    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    fn build_metal(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut metal_builder = metal::MetalBackend::builder();

        if let Some(pool_config) = builder.memory_pool_config {
            metal_builder = metal_builder.memory_pool(pool_config);
        }

        Ok(Box::new(metal_builder.build()?))
    }

    // Stub used when Metal is unavailable on this target/feature set.
    #[cfg(not(all(feature = "metal", target_os = "macos", target_arch = "aarch64")))]
    fn build_metal(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("Metal backend not enabled".into()))
    }

    // ROCm is feature-gated but the implementation does not exist yet, so
    // even the feature-enabled path returns an error.
    #[cfg(feature = "rocm")]
    fn build_rocm(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError(
            "ROCm backend not yet implemented".into(),
        ))
    }

    // Stub used when the `rocm` feature is compiled out.
    #[cfg(not(feature = "rocm"))]
    fn build_rocm(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError("ROCm backend not enabled".into()))
    }

    // Builds the WebGPU backend; only the pool's `max_size` maps onto the
    // WebGPU builder (as `max_buffer_size`).
    #[cfg(feature = "webgpu")]
    fn build_webgpu(builder: Self) -> BackendResult<Box<dyn Backend>> {
        let mut webgpu_builder = webgpu::WebGpuBackendBuilder::new();

        webgpu_builder = webgpu_builder.device_id(builder.device_id);

        if let Some(pool_config) = builder.memory_pool_config {
            if let Some(max_size) = pool_config.max_size {
                webgpu_builder = webgpu_builder.max_buffer_size(max_size as u64);
            }
        }

        webgpu_builder = webgpu_builder.enable_pipeline_cache(true);

        Ok(Box::new(webgpu_builder.build()))
    }

    // Stub used when the `webgpu` feature is compiled out.
    #[cfg(not(feature = "webgpu"))]
    fn build_webgpu(_builder: Self) -> BackendResult<Box<dyn Backend>> {
        Err(TorshError::BackendError(
            "WebGPU backend not enabled".into(),
        ))
    }
}
491
// Manual field-by-field `Clone` so `auto_select` can retry several backends
// from one configuration. `memory_pool_config` is the only field that needs
// an actual `.clone()`; the rest are copied by value.
impl Clone for BackendBuilder {
    fn clone(&self) -> Self {
        Self {
            backend_type: self.backend_type,
            device_id: self.device_id,
            memory_pool_config: self.memory_pool_config.clone(),
            num_threads: self.num_threads,
            enable_profiling: self.enable_profiling,
        }
    }
}
503
504pub fn auto() -> BackendResult<Box<dyn Backend>> {
506 BackendBuilder::new().build()
507}
508
509pub fn cpu() -> BackendResult<Box<dyn Backend>> {
511 BackendBuilder::new().backend_type(BackendType::Cpu).build()
512}
513
514pub fn cuda() -> BackendResult<Box<dyn Backend>> {
516 BackendBuilder::new()
517 .backend_type(BackendType::Cuda)
518 .build()
519}
520
521pub fn metal() -> BackendResult<Box<dyn Backend>> {
523 BackendBuilder::new()
524 .backend_type(BackendType::Metal)
525 .build()
526}
527
/// Lists every backend type that is compiled in and reports itself available.
///
/// The CPU backend is included whenever its feature is enabled (no runtime
/// probe); GPU backends are additionally checked via their `is_available()`.
#[allow(clippy::vec_init_then_push)]
pub fn available_backends() -> Vec<BackendType> {
    let mut backends = vec![];

    #[cfg(feature = "cpu")]
    backends.push(BackendType::Cpu);

    #[cfg(feature = "cuda")]
    if cuda::is_available() {
        backends.push(BackendType::Cuda);
    }

    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    if metal::is_available() {
        backends.push(BackendType::Metal);
    }

    #[cfg(feature = "rocm")]
    if rocm::is_available() {
        backends.push(BackendType::Rocm);
    }

    #[cfg(feature = "webgpu")]
    if webgpu::is_available() {
        backends.push(BackendType::WebGpu);
    }

    backends
}
558
/// Enumerates devices for every backend that is compiled in and available.
///
/// Backends that fail to initialize are skipped silently instead of failing
/// the whole enumeration. ROCm is omitted here (no implementation yet).
/// Note: the CPU and CUDA branches push whatever `devices()` returns, which
/// may be an empty list; only Metal/WebGPU filter empties.
pub fn enumerate_all_devices() -> BackendResult<Vec<(BackendType, Vec<Device>)>> {
    let mut all_devices = Vec::new();

    #[cfg(feature = "cpu")]
    {
        match cpu() {
            Ok(backend) => {
                if let Ok(devices) = backend.devices() {
                    all_devices.push((BackendType::Cpu, devices));
                }
            }
            Err(_) => {
                // CPU backend failed to build; skip it.
            }
        }
    }

    #[cfg(feature = "cuda")]
    if cuda::is_available() {
        // Try ascending device ordinals until one backend builds, then stop.
        // NOTE(review): assumes that one backend's `devices()` reports all
        // CUDA devices (hence the `break`) — confirm against cuda::CudaBackend.
        for device_id in 0..cuda::device_count().unwrap_or(0) {
            match BackendBuilder::new()
                .backend_type(BackendType::Cuda)
                .device_id(device_id as usize)
                .build()
            {
                Ok(backend) => {
                    if let Ok(devices) = backend.devices() {
                        all_devices.push((BackendType::Cuda, devices));
                        break;
                    }
                }
                Err(_) => continue,
            }
        }
    }

    #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
    if metal::is_available() {
        match BackendBuilder::new()
            .backend_type(BackendType::Metal)
            .build()
        {
            Ok(backend) => {
                if let Ok(devices) = backend.devices() {
                    // Only record Metal when it reports at least one device.
                    if !devices.is_empty() {
                        all_devices.push((BackendType::Metal, devices));
                    }
                }
            }
            Err(_) => {
                // Metal backend failed to build; skip it.
            }
        }
    }

    #[cfg(feature = "webgpu")]
    if webgpu::is_available() {
        match BackendBuilder::new()
            .backend_type(BackendType::WebGpu)
            .build()
        {
            Ok(backend) => {
                if let Ok(devices) = backend.devices() {
                    // Only record WebGPU when it reports at least one device.
                    if !devices.is_empty() {
                        all_devices.push((BackendType::WebGpu, devices));
                    }
                }
            }
            Err(_) => {
                // WebGPU backend failed to build; skip it.
            }
        }
    }

    Ok(all_devices)
}
643
644pub fn find_best_device(
646 selector: Option<device::DeviceSelector>,
647) -> BackendResult<(BackendType, Device)> {
648 let all_devices = enumerate_all_devices()?;
649
650 if all_devices.is_empty() {
651 return Err(TorshError::BackendError("No devices available".into()));
652 }
653
654 let selector = selector.unwrap_or_default();
655
656 for (backend_type, devices) in &all_devices {
658 for device in devices {
659 if selector.matches(device) {
660 return Ok((*backend_type, device.clone()));
661 }
662 }
663 }
664
665 let preference_order = [
667 BackendType::Cuda,
668 BackendType::Metal,
669 BackendType::WebGpu,
670 BackendType::Cpu,
671 ];
672
673 for preferred_backend in &preference_order {
674 for (backend_type, devices) in &all_devices {
675 if backend_type == preferred_backend && !devices.is_empty() {
676 return Ok((*backend_type, devices[0].clone()));
677 }
678 }
679 }
680
681 let (backend_type, devices) = &all_devices[0];
683 Ok((*backend_type, devices[0].clone()))
684}
685
/// Returns the number of devices usable for `backend_type`.
///
/// Backends that are compiled out or unavailable at runtime report `Ok(0)`.
/// `BackendType::Auto` sums the counts of all available backends.
pub fn device_count(backend_type: BackendType) -> BackendResult<usize> {
    match backend_type {
        // The host CPU is always exposed as exactly one logical device.
        BackendType::Cpu => Ok(1),

        #[cfg(feature = "cuda")]
        BackendType::Cuda => {
            if cuda::is_available() {
                Ok(cuda::device_count().unwrap_or(0) as usize)
            } else {
                Ok(0)
            }
        }

        #[cfg(not(feature = "cuda"))]
        BackendType::Cuda => Ok(0),

        #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
        BackendType::Metal => {
            if metal::is_available() {
                Ok(metal::device_count().unwrap_or(0))
            } else {
                Ok(0)
            }
        }

        #[cfg(not(all(feature = "metal", target_os = "macos", target_arch = "aarch64")))]
        BackendType::Metal => Ok(0),

        #[cfg(feature = "webgpu")]
        BackendType::WebGpu => {
            if webgpu::is_available() {
                Ok(webgpu::device_count().unwrap_or(0))
            } else {
                Ok(0)
            }
        }

        #[cfg(not(feature = "webgpu"))]
        BackendType::WebGpu => Ok(0),

        // ROCm backend is not implemented yet, so it never has devices.
        BackendType::Rocm => Ok(0),

        BackendType::Auto => {
            // Sum across every concrete available backend.
            let mut total = 0;
            for backend in available_backends() {
                if backend != BackendType::Auto {
                    total += device_count(backend)?;
                }
            }
            Ok(total)
        }
    }
}
740
741pub mod prelude {
743 pub use crate::{
744 auto,
745 available_backends,
746 compare_f32_values,
747 compare_f64_values,
748 cpu,
749 cuda,
750 device_count,
751 enumerate_all_devices,
752 find_best_device,
753 metal,
754 run_cross_backend_validation,
755 run_hardware_optimization_tests,
756 run_lightweight_hardware_tests,
757 AdaptiveKernelSelector,
758 Backend,
759 BackendBuilder,
760 BackendCapabilities,
761 BackendOps,
762 BackendPlugin,
763 BackendRegistry,
764 BackendResourceManager,
765 BackendResult,
766 BackendType,
767 BenchmarkResult,
768 Buffer,
769 CompactionPlan,
770 CrossBackendValidator,
771 DefragmentationManager,
772 DefragmentationStats,
773 Device,
774 ExecutionModel,
775 ExtendedCapabilities,
776 HardwareFeature,
777 HardwareOptimizationTester,
778 KernelImplementation,
779 KernelSelection,
780 KernelVariant,
781 MemoryHierarchy,
782 MemoryPool,
783 OperationType,
784 PerformanceMeasurement,
785 PerformancePrediction,
786 PerformanceReport,
787 PerformanceTrend,
788 PerformanceTuningCoordinator,
789 PluginMetadata,
790 PrecisionMode,
791 ResourceLimits,
792 ResourceStatistics,
793 ResourceUsage,
794 RuntimePerformanceModeler,
795 SelectionAlgorithm,
796 TransferDirection,
797 TransferMode,
798 TuningParameters,
799 TuningRecommendation,
800 WorkloadCharacteristics,
801 ZeroCopyCapabilities,
802 ZeroCopyManager,
803 ZeroCopyStats,
804 ZeroCopyTransfer,
805 VERSION,
807 VERSION_MAJOR,
808 VERSION_MINOR,
809 VERSION_PATCH,
810 };
811}
812
813#[cfg(test)]
814mod tests {
815 use super::*;
816
    // A CPU builder with explicit device 0 must construct successfully.
    #[test]
    fn test_backend_builder() {
        let builder = BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .device_id(0);

        let result = builder.build();
        // Surface the failure reason before asserting, to aid debugging.
        if let Err(e) = &result {
            eprintln!("Backend build failed: {:?}", e);
        }
        assert!(result.is_ok());
    }

    // The CPU backend is always compiled in for tests, so the list of
    // available backends is never empty and always contains `Cpu`.
    #[test]
    fn test_available_backends() {
        let backends = available_backends();
        assert!(!backends.is_empty());
        assert!(backends.contains(&BackendType::Cpu));
    }

    #[test]
    fn test_device_count() {
        // Exactly one logical CPU device is reported.
        assert_eq!(device_count(BackendType::Cpu).unwrap(), 1);

        // Auto aggregates every available backend, so at least the CPU.
        let auto_count = device_count(BackendType::Auto).unwrap();
        assert!(auto_count >= 1);
        for backend_type in available_backends() {
            if backend_type != BackendType::Auto {
                let count = device_count(backend_type).unwrap();
                assert!(count < usize::MAX);
            }
        }
    }
857
    #[test]
    fn test_enumerate_all_devices() {
        let devices = enumerate_all_devices().unwrap();
        assert!(!devices.is_empty());
        // The CPU backend must always appear in test builds.
        let has_cpu = devices
            .iter()
            .any(|(backend_type, _)| *backend_type == BackendType::Cpu);
        assert!(has_cpu);

        // Every enumerated backend should expose at least one device.
        for (backend_type, device_list) in &devices {
            assert!(
                !device_list.is_empty(),
                "Backend {:?} should have at least one device",
                backend_type
            );
        }
    }

    #[test]
    fn test_find_best_device() {
        let (backend_type, device) = find_best_device(None).unwrap();

        // The winner must be a concrete backend (never Auto or ROCm).
        assert!(matches!(
            backend_type,
            BackendType::Cpu | BackendType::Cuda | BackendType::Metal | BackendType::WebGpu
        ));

        assert!(!device.name().is_empty());
    }

    #[test]
    fn test_find_best_device_with_selector() {
        use crate::device::{DeviceSelector, DeviceType};

        // Restricting the selector to CPU must yield the CPU backend/device.
        let selector = DeviceSelector::new().with_device_type(DeviceType::Cpu);
        let result = find_best_device(Some(selector));

        assert!(result.is_ok());
        let (backend_type, device) = result.unwrap();
        assert_eq!(backend_type, BackendType::Cpu);
        assert_eq!(device.device_type(), torsh_core::device::DeviceType::Cpu);
    }
906
    #[test]
    fn test_unified_error_handling() {
        use crate::error::{conversion, ErrorContext};

        // Context formatting should include every attached field.
        let context = ErrorContext::new("test_operation")
            .with_backend("TestBackend")
            .with_device("test:0")
            .with_details("test details");

        let formatted = context.format();
        assert!(formatted.contains("test_operation"));
        assert!(formatted.contains("backend: TestBackend"));
        assert!(formatted.contains("device: test:0"));
        assert!(formatted.contains("details: test details"));

        // CUDA conversion helper embeds kernel name and device ordinal.
        let cuda_error =
            conversion::cuda_error_with_context("Test CUDA error", "test_kernel", Some(0));
        let error_str = cuda_error.to_string();
        assert!(error_str.contains("CUDA"));
        assert!(error_str.contains("test_kernel"));
        assert!(error_str.contains("cuda:0"));

        // CPU conversion helper embeds the operation name.
        let cpu_error = conversion::cpu_error_with_context("Test CPU error", "test_operation");
        let error_str = cpu_error.to_string();
        assert!(error_str.contains("CPU"));
        assert!(error_str.contains("test_operation"));

        // Memory errors report the requested size, backend, and device.
        let memory_error =
            conversion::memory_error_with_context("Out of memory", 1024, "CUDA", Some("cuda:0"));
        let error_str = memory_error.to_string();
        assert!(error_str.contains("memory_allocation"));
        assert!(error_str.contains("1024 bytes"));
        assert!(error_str.contains("CUDA"));
        assert!(error_str.contains("cuda:0"));
    }

    #[test]
    fn test_error_context_extension() {
        use torsh_core::error::TorshError;

        // `with_operation` must keep the original message and include the
        // operation name in the rendered error.
        let result: Result<(), TorshError> =
            Err(TorshError::ComputeError("Test error".to_string()));
        let with_context = crate::error::ErrorContextExt::with_operation(result, "test_operation");

        assert!(with_context.is_err());
        let error_str = with_context.unwrap_err().to_string();
        assert!(error_str.contains("test_operation"));
        assert!(error_str.contains("Test error"));
    }
961
    // `device_id` is not validated by the builder; the error surfaces later
    // from `create_device` on the constructed backend.
    #[test]
    fn test_invalid_device_id_error() {
        let builder = BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .device_id(999);
        let backend = builder.build().unwrap();
        let result = backend.create_device(999);
        assert!(result.is_err());

        let error_str = result.unwrap_err().to_string();
        assert!(error_str.contains("999"));
        assert!(error_str.contains("not found"));
    }

    // A zero thread count still builds.
    // NOTE(review): inferred from the Ok assertion — confirm the CPU backend
    // treats 0 as "use default" rather than spawning zero workers.
    #[test]
    fn test_backend_builder_invalid_thread_count() {
        let builder = BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .num_threads(0);

        let result = builder.build();
        assert!(result.is_ok());
    }

    // An absurdly large thread count must not fail the build.
    #[test]
    fn test_backend_builder_extreme_thread_count() {
        let builder = BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .num_threads(10000);

        let result = builder.build();
        if let Err(ref e) = result {
            eprintln!("Backend build failed with extreme thread count: {:?}", e);
        }
        assert!(result.is_ok());
    }

    // Requesting a backend whose feature is compiled out must fail with a
    // "not enabled" error rather than panicking.
    #[test]
    fn test_unavailable_backend_selection() {
        #[cfg(not(feature = "cuda"))]
        {
            let builder = BackendBuilder::new().backend_type(BackendType::Cuda);
            let result = builder.build();
            assert!(result.is_err());

            let error_str = result.unwrap_err().to_string();
            assert!(error_str.contains("not enabled"));
        }

        #[cfg(not(feature = "metal"))]
        {
            let builder = BackendBuilder::new().backend_type(BackendType::Metal);
            let result = builder.build();
            assert!(result.is_err());

            let error_str = result.unwrap_err().to_string();
            assert!(error_str.contains("not enabled"));
        }
    }
1031
    #[test]
    fn test_device_count_edge_cases() {
        // Compiled-out backends report zero devices.
        #[cfg(not(feature = "cuda"))]
        {
            let count = device_count(BackendType::Cuda).unwrap();
            assert_eq!(count, 0);
        }

        #[cfg(not(feature = "metal"))]
        {
            let count = device_count(BackendType::Metal).unwrap();
            assert_eq!(count, 0);
        }

        // ROCm is not implemented yet, so it always reports zero.
        let count = device_count(BackendType::Rocm).unwrap();
        assert_eq!(count, 0);
    }

    // Even when the selector matches nothing (e.g. CUDA in a CPU-only
    // build), `find_best_device` falls back to any available device.
    #[test]
    fn test_find_best_device_no_match() {
        use crate::device::{DeviceSelector, DeviceType};

        let selector = DeviceSelector::new().with_device_type(DeviceType::Cuda);
        let result = find_best_device(Some(selector));

        assert!(result.is_ok());
    }

    // Extreme sizes and growth factors are stored as-is; validation is
    // deferred to the consuming backend.
    #[test]
    fn test_memory_pool_config_edge_cases() {
        let config = MemoryPoolConfig::new(0);
        assert_eq!(config.initial_size, 0);

        let config = MemoryPoolConfig::new(usize::MAX);
        assert_eq!(config.initial_size, usize::MAX);

        let config = MemoryPoolConfig::new(1024).with_growth_factor(0.0);
        assert_eq!(config.growth_factor, 0.0);

        let config = MemoryPoolConfig::new(1024).with_growth_factor(-1.0);
        assert_eq!(config.growth_factor, -1.0);
    }

    // Alignment values, including the degenerate 0 and 1, round-trip
    // unchanged through the config builder.
    #[test]
    fn test_memory_pool_config_alignment_edge_cases() {
        let config = MemoryPoolConfig::new(1024).with_alignment(0);
        assert_eq!(config.alignment, 0);

        let config = MemoryPoolConfig::new(1024).with_alignment(1);
        assert_eq!(config.alignment, 1);

        let config = MemoryPoolConfig::new(1024).with_alignment(4096);
        assert_eq!(config.alignment, 4096);
    }
1094
    // A 10 KB message must survive error-context formatting untruncated.
    #[test]
    fn test_error_handling_with_long_messages() {
        use crate::error::conversion;

        let long_message = "x".repeat(10000);
        let error = conversion::cpu_error_with_context(long_message.clone(), "test_operation");

        let error_str = error.to_string();
        assert!(error_str.contains(&long_message));
        assert!(error_str.len() > 10000);
    }

    // Unicode (CJK, accents, emoji) and control characters must pass through
    // error formatting unchanged.
    #[test]
    fn test_error_handling_with_special_characters() {
        use crate::error::conversion;

        let special_message = "Error: 測試 ñoño 🚀 \n\t\r";
        let error = conversion::cpu_error_with_context(special_message, "test_unicode_operation");

        let error_str = error.to_string();
        assert!(error_str.contains("測試"));
        assert!(error_str.contains("🚀"));
    }

    // Build CPU backends from 10 threads at once; at least one must succeed,
    // demonstrating backend creation is safe under concurrency.
    #[test]
    fn test_concurrent_backend_creation() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        use std::thread;

        let success_count = Arc::new(AtomicUsize::new(0));
        let error_count = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];

        for _ in 0..10 {
            let success_count = Arc::clone(&success_count);
            let error_count = Arc::clone(&error_count);

            let handle = thread::spawn(move || {
                let builder = BackendBuilder::new().backend_type(BackendType::Cpu);
                match builder.build() {
                    Ok(_) => success_count.fetch_add(1, Ordering::Relaxed),
                    Err(_) => error_count.fetch_add(1, Ordering::Relaxed),
                };
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let successes = success_count.load(Ordering::Relaxed);
        assert!(
            successes > 0,
            "No backend creation succeeded in concurrent test"
        );
    }
1159
    // A deliberately tiny (1 KB) pool should still produce a usable backend
    // with a default device.
    #[test]
    fn test_backend_memory_pressure_simulation() {
        let backend = BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .memory_pool(MemoryPoolConfig::new(1024))
            .build()
            .unwrap();

        let device = backend.default_device().unwrap();
        assert!(!device.name().is_empty());
    }

    // Two consecutive enumerations must agree on count and backend set.
    #[test]
    fn test_enumerate_devices_consistency() {
        let devices1 = enumerate_all_devices().unwrap();
        let devices2 = enumerate_all_devices().unwrap();

        assert_eq!(devices1.len(), devices2.len());

        let backend_types1: std::collections::HashSet<_> =
            devices1.iter().map(|(bt, _)| *bt).collect();
        let backend_types2: std::collections::HashSet<_> =
            devices2.iter().map(|(bt, _)| *bt).collect();
        assert_eq!(backend_types1, backend_types2);
    }

    // An empty selector matches anything, so some device must be found.
    #[test]
    fn test_device_selector_empty_criteria() {
        use crate::device::DeviceSelector;

        let selector = DeviceSelector::new();
        let result = find_best_device(Some(selector));
        assert!(result.is_ok());
    }

    // All builder setters chain fluently and still build successfully.
    #[test]
    fn test_backend_builder_chain_operations() {
        let builder = BackendBuilder::new()
            .backend_type(BackendType::Cpu)
            .device_id(0)
            .num_threads(4)
            .memory_pool(MemoryPoolConfig::new(1024 * 1024))
            .enable_profiling(true);

        let result = builder.build();
        assert!(result.is_ok());
    }

    // Auto selection must always settle on some working backend (the CPU
    // backend acts as the final fallback in `auto_select`).
    #[test]
    fn test_auto_backend_selection_fallback() {
        let builder = BackendBuilder::new().backend_type(BackendType::Auto);
        let result = builder.build();

        assert!(result.is_ok());

        let backend = result.unwrap();

        let devices = backend.devices().unwrap();
        assert!(!devices.is_empty());
    }
1230
1231 #[test]
1234 fn test_memory_pool_zero_max_size() {
1235 let config = MemoryPoolConfig::new(1024).with_max_size(0);
1237 let builder = BackendBuilder::new()
1238 .backend_type(BackendType::Cpu)
1239 .memory_pool(config);
1240
1241 let result = builder.build();
1243 match result {
1245 Ok(_) => {
1246 }
1248 Err(_) => {
1249 }
1251 }
1252 }
1253
1254 #[test]
1255 fn test_memory_pool_negative_growth_factor() {
1256 let config = MemoryPoolConfig::new(1024).with_growth_factor(-0.5);
1258 let builder = BackendBuilder::new()
1259 .backend_type(BackendType::Cpu)
1260 .memory_pool(config);
1261
1262 let result = builder.build();
1264 match result {
1266 Ok(_) => {
1267 }
1269 Err(_) => {
1270 }
1272 }
1273 }
1274
1275 #[test]
1276 fn test_device_selector_with_conflicting_criteria() {
1277 use crate::device::{DeviceSelector, DeviceType};
1278
1279 let selector = DeviceSelector::new()
1281 .with_device_type(DeviceType::Cpu)
1282 .with_device_type(DeviceType::Cuda); let result = find_best_device(Some(selector));
1285 assert!(result.is_ok());
1287 }
1288
1289 #[test]
1290 fn test_backend_builder_cloning_with_modifications() {
1291 let original_builder = BackendBuilder::new()
1293 .backend_type(BackendType::Cpu)
1294 .num_threads(2);
1295
1296 let mut cloned_builder = original_builder.clone();
1297 cloned_builder = cloned_builder.num_threads(4);
1298
1299 let original_result = original_builder.build();
1301 let cloned_result = cloned_builder.build();
1302
1303 assert!(original_result.is_ok());
1304 assert!(cloned_result.is_ok());
1305 }
1306
1307 #[test]
1308 fn test_error_context_with_empty_strings() {
1309 use crate::error::ErrorContext;
1310
1311 let context = ErrorContext::new("")
1313 .with_backend("")
1314 .with_device("")
1315 .with_details("");
1316
1317 let formatted = context.format();
1318 assert!(!formatted.is_empty());
1320 }
1321
1322 #[test]
1323 fn test_error_context_with_null_characters() {
1324 use crate::error::ErrorContext;
1325
1326 let context = ErrorContext::new("op\0eration")
1328 .with_backend("back\0end")
1329 .with_device("dev\0ice")
1330 .with_details("deta\0ils");
1331
1332 let formatted = context.format();
1333 assert!(!formatted.is_empty());
1335 }
1336
1337 #[test]
1338 fn test_memory_manager_extreme_alignment() {
1339 let config = MemoryPoolConfig::new(1024).with_alignment(usize::MAX);
1341 let builder = BackendBuilder::new()
1342 .backend_type(BackendType::Cpu)
1343 .memory_pool(config);
1344
1345 let result = builder.build();
1347 match result {
1349 Ok(_) => {
1350 }
1352 Err(_) => {
1353 }
1355 }
1356 }
1357
1358 #[test]
1359 fn test_backend_resource_cleanup() {
1360 let backend = BackendBuilder::new()
1362 .backend_type(BackendType::Cpu)
1363 .build()
1364 .unwrap();
1365
1366 let _device = backend.default_device().unwrap();
1368 let _devices = backend.devices().unwrap();
1369
1370 drop(backend);
1372
1373 }
1376
1377 #[test]
1378 fn test_available_backends_consistency() {
1379 let backends1 = available_backends();
1381 let backends2 = available_backends();
1382
1383 assert_eq!(backends1, backends2);
1385
1386 assert!(backends1.contains(&BackendType::Cpu));
1388
1389 assert!(!backends1.contains(&BackendType::Auto));
1391 }
1392
1393 #[test]
1394 fn test_device_count_consistency() {
1395 for backend_type in available_backends() {
1397 let count1 = device_count(backend_type).unwrap();
1398 let count2 = device_count(backend_type).unwrap();
1399
1400 assert_eq!(
1401 count1, count2,
1402 "Device count should be consistent for {:?}",
1403 backend_type
1404 );
1405 }
1406 }
1407
1408 #[test]
1409 fn test_enumerate_devices_with_no_backends() {
1410 let devices = enumerate_all_devices().unwrap();
1413
1414 assert!(!devices.is_empty());
1416
1417 }
1420
1421 #[test]
1422 fn test_backend_capability_reporting() {
1423 let backend = BackendBuilder::new()
1425 .backend_type(BackendType::Cpu)
1426 .build()
1427 .unwrap();
1428
1429 let capabilities = backend.capabilities();
1430
1431 assert!(!capabilities.supported_dtypes.is_empty());
1433
1434 assert!(capabilities
1436 .supported_dtypes
1437 .contains(&torsh_core::DType::F32));
1438 assert!(capabilities
1439 .supported_dtypes
1440 .contains(&torsh_core::DType::F64));
1441 }
1442
1443 #[test]
1444 fn test_error_recovery_and_retry_logic() {
1445 let mut retry_count = 0;
1447 let max_retries = 3;
1448
1449 loop {
1450 let result = BackendBuilder::new()
1452 .backend_type(BackendType::Cpu)
1453 .num_threads(1) .build();
1455
1456 match result {
1457 Ok(_) => {
1458 break;
1460 }
1461 Err(e) => {
1462 retry_count += 1;
1463 if retry_count >= max_retries {
1464 let error_msg = e.to_string();
1466 assert!(!error_msg.is_empty());
1467 break;
1468 }
1469 std::thread::sleep(std::time::Duration::from_millis(10));
1471 }
1472 }
1473 }
1474 }
1475
1476 #[test]
1477 fn test_backend_performance_hints() {
1478 let backend = BackendBuilder::new()
1480 .backend_type(BackendType::Cpu)
1481 .build()
1482 .unwrap();
1483
1484 let hints = backend.performance_hints();
1485
1486 assert!(hints.optimal_batch_size > 0);
1488
1489 assert!(hints.optimal_batch_size <= 1024 * 1024); }
1492
1493 #[test]
1494 fn test_cross_backend_type_compatibility() {
1495 let cpu_result = BackendBuilder::new().backend_type(BackendType::Cpu).build();
1497
1498 assert!(cpu_result.is_ok());
1499
1500 #[cfg(feature = "cuda")]
1502 {
1503 let cuda_result = BackendBuilder::new()
1504 .backend_type(BackendType::Cuda)
1505 .build();
1506
1507 match cuda_result {
1509 Ok(_) => {
1510 }
1512 Err(_) => {
1513 }
1515 }
1516 }
1517 }
1518
1519 #[test]
1520 fn test_backend_state_isolation() {
1521 let backend1 = BackendBuilder::new()
1523 .backend_type(BackendType::Cpu)
1524 .num_threads(2)
1525 .build()
1526 .unwrap();
1527
1528 let backend2 = BackendBuilder::new()
1529 .backend_type(BackendType::Cpu)
1530 .num_threads(4)
1531 .build()
1532 .unwrap();
1533
1534 let device1 = backend1.default_device().unwrap();
1536 let device2 = backend2.default_device().unwrap();
1537
1538 assert!(!device1.name().is_empty());
1539 assert!(!device2.name().is_empty());
1540 }
1541
1542 #[test]
1543 fn test_profiling_enablement() {
1544 let backend = BackendBuilder::new()
1546 .backend_type(BackendType::Cpu)
1547 .enable_profiling(true)
1548 .build()
1549 .unwrap();
1550
1551 let device = backend.default_device().unwrap();
1553 assert!(!device.name().is_empty());
1554
1555 let backend_no_prof = BackendBuilder::new()
1557 .backend_type(BackendType::Cpu)
1558 .enable_profiling(false)
1559 .build()
1560 .unwrap();
1561
1562 let device_no_prof = backend_no_prof.default_device().unwrap();
1563 assert!(!device_no_prof.name().is_empty());
1564 }
1565
1566 #[test]
1569 #[ignore = "Requires CUDA hardware - run with --ignored flag"]
1570 fn test_cross_backend_validation_integration() {
1571 use crate::cross_backend_validation::{
1572 run_cross_backend_validation, CrossBackendValidator,
1573 };
1574
1575 let validator = CrossBackendValidator::new();
1577 assert!(!validator.available_backends().is_empty());
1578
1579 match validator.validate_device_creation() {
1583 Ok(()) => {} Err(e) => eprintln!("Device creation validation warning: {}", e),
1585 }
1586 match validator.validate_capabilities_consistency() {
1587 Ok(()) => {} Err(e) => eprintln!("Capabilities consistency validation warning: {}", e),
1589 }
1590
1591 match run_cross_backend_validation() {
1593 Ok(()) => {
1594 }
1596 Err(e) => {
1597 eprintln!("Cross-backend validation warning: {}", e);
1600 }
1601 }
1602 }
1603
    #[test]
    fn test_floating_point_comparison_utilities() {
        use crate::cross_backend_validation::{compare_f32_values, compare_f64_values};

        // f32: exact equality, a value inside the 1e-6 tolerance, and a
        // clearly different value outside it.
        assert!(compare_f32_values(1.0, 1.0, 1e-6));
        assert!(compare_f32_values(1.0, 1.0000005, 1e-6));
        assert!(!compare_f32_values(1.0, 1.1, 1e-6));

        // f64: same three cases at a tighter tolerance.
        assert!(compare_f64_values(1.0, 1.0, 1e-11));
        assert!(compare_f64_values(1.0, 1.00000000001, 1.1e-11));
        assert!(!compare_f64_values(1.0, 1.1, 1e-11));

        // Special values: NaN is treated as equal to NaN here (unlike IEEE
        // `==`), and infinities match only when their signs agree.
        assert!(compare_f32_values(f32::NAN, f32::NAN, 1e-6));
        assert!(compare_f32_values(f32::INFINITY, f32::INFINITY, 1e-6));
        assert!(!compare_f32_values(f32::INFINITY, f32::NEG_INFINITY, 1e-6));

        assert!(compare_f64_values(f64::NAN, f64::NAN, 1e-12));
        assert!(compare_f64_values(f64::INFINITY, f64::INFINITY, 1e-12));
        assert!(!compare_f64_values(f64::INFINITY, f64::NEG_INFINITY, 1e-12));
    }
1626
1627 #[test]
1630 fn test_hardware_optimization_integration() {
1631 use crate::hardware_optimization_tests::{
1632 run_lightweight_hardware_tests, HardwareOptimizationTester,
1633 };
1634
1635 let tester = HardwareOptimizationTester::new();
1637 assert!(tester.simd_tests_enabled);
1638 assert!(tester.platform_tests_enabled);
1639 assert!(!tester.performance_tests_enabled); match run_lightweight_hardware_tests() {
1643 Ok(()) => {
1644 }
1646 Err(e) => {
1647 eprintln!("Hardware optimization tests warning: {}", e);
1649 }
1650 }
1651 }
1652
1653 #[test]
1654 fn test_hardware_optimization_tester_configuration() {
1655 use crate::hardware_optimization_tests::HardwareOptimizationTester;
1656
1657 let mut tester = HardwareOptimizationTester::new();
1659
1660 tester.simd_tests_enabled = false;
1662 tester.platform_tests_enabled = true;
1663 tester.performance_tests_enabled = false;
1664
1665 assert!(!tester.simd_tests_enabled);
1667 assert!(tester.platform_tests_enabled);
1668 assert!(!tester.performance_tests_enabled);
1669 }
1670}