tsai_compute/
lib.rs

//! # tsai_compute
//!
//! A heterogeneous compute abstraction layer for tsai-rs.
//!
//! This crate provides a unified interface for compute operations across
//! different hardware backends:
//!
//! - **CPU**: SIMD-optimized (AVX2, AVX-512, NEON) with optional NUMA awareness
//! - **Metal**: Apple GPU via the Metal framework (macOS/iOS)
//! - **CUDA**: NVIDIA GPU via the CUDA toolkit
//! - **Vulkan**: Cross-platform GPU via Vulkan
//! - **OpenCL**: Cross-platform GPU/CPU via OpenCL
//! - **ROCm**: AMD GPU via ROCm/HIP
//!
//! ## Features
//!
//! - Runtime device discovery and selection
//! - Unified buffer and memory management
//! - Command encoding and synchronization
//! - Workload-aware scheduling
//! - Burn framework integration (optional, via the `burn-bridge` feature)
//!
//! ## Quick Start
//!
//! ```rust,ignore
//! use tsai_compute::{HardwareDiscovery, DevicePool};
//!
//! // Discover all available devices
//! let pool = HardwareDiscovery::discover_all()?;
//! pool.print_summary();
//!
//! // Get the best device
//! let device = pool.best_device().expect("No devices available");
//! println!("Using: {}", device.name());
//! ```
//!
//! ## Feature Flags
//!
//! - `cpu` (default): CPU backend with SIMD
//! - `numa`: NUMA-aware memory allocation
//! - `cuda`: NVIDIA CUDA support
//! - `metal`: Apple Metal support (macOS only)
//! - `vulkan`: Vulkan compute support
//! - `opencl`: OpenCL support
//! - `rocm`: AMD ROCm support
//! - `burn-bridge`: Burn framework integration

// Crate-wide lint levels: undocumented public items and anything in clippy's
// `all` group are reported as warnings.
#![warn(missing_docs, clippy::all)]

51pub mod backend;
52pub mod bridge;
53pub mod device;
54pub mod discovery;
55pub mod error;
56pub mod memory;
57pub mod scheduler;
58
59// Re-export commonly used types
60pub use backend::{CommandEncoder, ComputeBackend, Fence};
61pub use device::{
62    ComputeDevice, DeviceCapabilities, DeviceFeature, DeviceId, DevicePool, DeviceType,
63    SelectionStrategy, SimdLevel,
64};
65pub use discovery::{get_device_pool, get_discovery_time_us, refresh_device_pool, HardwareDiscovery};
66pub use error::{ComputeError, ComputeResult};
67pub use memory::{Buffer, BufferUsage, MemoryPool};
68pub use scheduler::{Priority, RoundRobinScheduler, Scheduler, SimpleScheduler, Workload, WorkloadScheduler};
69
70// Backend-specific re-exports
71pub use backend::cpu::{CpuBackend, CpuBuffer, CpuDevice};
72
73#[cfg(target_os = "macos")]
74pub use backend::metal::{MetalBackend, MetalBuffer, MetalDevice};
75
76#[cfg(feature = "cuda")]
77pub use backend::cuda::{CudaBackend, CudaBuffer, CudaDevice};
78
79#[cfg(feature = "vulkan")]
80pub use backend::vulkan::{VulkanBackend, VulkanBuffer, VulkanDevice};
81
82#[cfg(feature = "opencl")]
83pub use backend::opencl::{OpenClBackend, OpenClBuffer, OpenClDevice};
84
85#[cfg(feature = "rocm")]
86pub use backend::rocm::{RocmBackend, RocmBuffer, RocmDevice};
87
88/// Prelude module for convenient imports.
89pub mod prelude {
90    pub use crate::backend::{CommandEncoder, ComputeBackend, Fence};
91    pub use crate::device::{
92        ComputeDevice, DeviceCapabilities, DeviceFeature, DeviceId, DevicePool, DeviceType,
93    };
94    pub use crate::discovery::{get_device_pool, HardwareDiscovery};
95    pub use crate::error::{ComputeError, ComputeResult};
96    pub use crate::memory::{Buffer, BufferUsage};
97    pub use crate::scheduler::{Scheduler, Workload};
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Discovery on the host machine must succeed and report at least one device.
    #[test]
    fn test_hardware_discovery() {
        let pool = HardwareDiscovery::discover_all().expect("hardware discovery failed");
        assert!(pool.has_devices(), "discovery returned an empty device pool");

        println!("\n=== tsai_compute Device Discovery ===\n");
        pool.print_summary();
    }

    /// The CPU backend can be created on its first device and allocate a buffer
    /// of exactly the requested size.
    #[test]
    fn test_cpu_backend() {
        let devices = CpuBackend::enumerate_devices().expect("CPU device enumeration failed");
        let device = devices.first().expect("CPU backend reported no devices");

        let backend = CpuBackend::new(device).expect("failed to create CPU backend");
        let buffer = backend
            .allocate_buffer(1024, BufferUsage::HOST_VISIBLE)
            .expect("failed to allocate a 1024-byte HOST_VISIBLE buffer");
        assert_eq!(buffer.size(), 1024);
    }

    /// A discovered pool yields a best device and at least one CPU device.
    #[test]
    fn test_device_selection() {
        let pool = HardwareDiscovery::discover_all().expect("hardware discovery failed");

        // Best-device selection.
        assert!(pool.best_device().is_some(), "pool did not select a best device");

        // Device filtering: the CPU backend is re-exported unconditionally, so a
        // CPU device is expected on every host.
        assert!(!pool.cpu_devices().is_empty(), "no CPU devices in the pool");
    }

    /// The workload scheduler can place a deliberately small workload.
    #[test]
    fn test_scheduler() {
        let pool = HardwareDiscovery::discover_all().expect("hardware discovery failed");
        let scheduler = WorkloadScheduler::new();

        let workload = Workload::new()
            .with_flops(1_000_000)
            .with_memory(1024 * 1024);

        let device = scheduler
            .select_device(&pool, &workload)
            .expect("scheduler failed to select a device for a small workload");
        println!("Selected device: {:?}", device);
    }
}