1use crate::error::{ClusteringError, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum GpuBackend {
13 Cuda,
15 OpenCl,
17 Rocm,
19 OneApi,
21 Metal,
23 CpuFallback,
25}
26
27impl Default for GpuBackend {
28 fn default() -> Self {
29 GpuBackend::CpuFallback
30 }
31}
32
33impl std::fmt::Display for GpuBackend {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 GpuBackend::Cuda => write!(f, "CUDA"),
37 GpuBackend::OpenCl => write!(f, "OpenCL"),
38 GpuBackend::Rocm => write!(f, "ROCm"),
39 GpuBackend::OneApi => write!(f, "Intel OneAPI"),
40 GpuBackend::Metal => write!(f, "Apple Metal"),
41 GpuBackend::CpuFallback => write!(f, "CPU Fallback"),
42 }
43 }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct GpuDevice {
49 pub device_id: u32,
51 pub name: String,
53 pub total_memory: usize,
55 pub available_memory: usize,
57 pub compute_capability: String,
59 pub compute_units: u32,
61 pub backend: GpuBackend,
63 pub supports_double_precision: bool,
65}
66
67impl GpuDevice {
68 pub fn new(
70 device_id: u32,
71 name: String,
72 total_memory: usize,
73 available_memory: usize,
74 compute_capability: String,
75 compute_units: u32,
76 backend: GpuBackend,
77 supports_double_precision: bool,
78 ) -> Self {
79 Self {
80 device_id,
81 name,
82 total_memory,
83 available_memory,
84 compute_capability,
85 compute_units,
86 backend,
87 supports_double_precision,
88 }
89 }
90
91 pub fn memory_utilization(&self) -> f64 {
93 if self.total_memory == 0 {
94 0.0
95 } else {
96 100.0 * (1.0 - (self.available_memory as f64 / self.total_memory as f64))
97 }
98 }
99
100 pub fn is_suitable_for_double_precision(&self) -> bool {
102 self.supports_double_precision
103 }
104
105 pub fn get_device_score(&self) -> f64 {
107 let memory_score = self.available_memory as f64 / 1_000_000_000.0; let compute_score = self.compute_units as f64;
109 let precision_bonus = if self.supports_double_precision {
110 1.5
111 } else {
112 1.0
113 };
114
115 (memory_score + compute_score) * precision_bonus
116 }
117}
118
119#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
121pub enum DeviceSelection {
122 First,
124 MostMemory,
126 HighestCompute,
128 Specific(u32),
130 Auto,
132 Fastest,
134}
135
136impl Default for DeviceSelection {
137 fn default() -> Self {
138 DeviceSelection::Auto
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct GpuConfig {
145 pub preferred_backend: GpuBackend,
147 pub device_selection: DeviceSelection,
149 pub auto_fallback: bool,
151 pub memory_pool_size: Option<usize>,
153 pub optimize_memory: bool,
155 pub backend_options: HashMap<String, String>,
157}
158
159impl Default for GpuConfig {
160 fn default() -> Self {
161 Self {
162 preferred_backend: GpuBackend::CpuFallback,
163 device_selection: DeviceSelection::Auto,
164 auto_fallback: true,
165 memory_pool_size: None,
166 optimize_memory: true,
167 backend_options: HashMap::new(),
168 }
169 }
170}
171
172impl GpuConfig {
173 pub fn new(backend: GpuBackend) -> Self {
175 Self {
176 preferred_backend: backend,
177 ..Default::default()
178 }
179 }
180
181 pub fn with_device_selection(mut self, strategy: DeviceSelection) -> Self {
183 self.device_selection = strategy;
184 self
185 }
186
187 pub fn with_memory_pool_size(mut self, size: usize) -> Self {
189 self.memory_pool_size = Some(size);
190 self
191 }
192
193 pub fn without_fallback(mut self) -> Self {
195 self.auto_fallback = false;
196 self
197 }
198
199 pub fn with_backend_option(mut self, key: String, value: String) -> Self {
201 self.backend_options.insert(key, value);
202 self
203 }
204
205 pub fn cuda() -> Self {
207 Self::new(GpuBackend::Cuda)
208 }
209
210 pub fn opencl() -> Self {
212 Self::new(GpuBackend::OpenCl)
213 }
214
215 pub fn rocm() -> Self {
217 Self::new(GpuBackend::Rocm)
218 }
219
220 pub fn metal() -> Self {
222 Self::new(GpuBackend::Metal)
223 }
224
225 pub fn validate(&self) -> Result<()> {
227 if let DeviceSelection::Specific(id) = self.device_selection {
228 if id > 64 {
229 return Err(ClusteringError::InvalidInput(
230 "Device ID too high".to_string(),
231 ));
232 }
233 }
234
235 if let Some(pool_size) = self.memory_pool_size {
236 if pool_size < 1024 * 1024 {
237 return Err(ClusteringError::InvalidInput(
238 "Memory pool size too small (minimum 1MB)".to_string(),
239 ));
240 }
241 }
242
243 Ok(())
244 }
245}
246
247#[derive(Debug)]
249pub struct GpuContext {
250 pub device: GpuDevice,
252 pub config: GpuConfig,
254 pub gpu_available: bool,
256 pub backend_context: BackendContext,
258}
259
260impl GpuContext {
261 pub fn new(device: GpuDevice, config: GpuConfig) -> Result<Self> {
263 config.validate()?;
264
265 let gpu_available = Self::check_gpu_availability(&device, &config);
266 let backend_context = BackendContext::new(&device.backend)?;
267
268 Ok(Self {
269 device,
270 config,
271 gpu_available,
272 backend_context,
273 })
274 }
275
276 fn check_gpu_availability(device: &GpuDevice, config: &GpuConfig) -> bool {
278 match (device.backend, config.preferred_backend) {
280 (GpuBackend::CpuFallback, _) => false,
281 (backend1, backend2) if backend1 == backend2 => true,
282 _ => config.auto_fallback,
283 }
284 }
285
286 pub fn effective_backend(&self) -> GpuBackend {
288 if self.gpu_available {
289 self.device.backend
290 } else {
291 GpuBackend::CpuFallback
292 }
293 }
294
295 pub fn is_gpu_accelerated(&self) -> bool {
297 self.gpu_available && self.device.backend != GpuBackend::CpuFallback
298 }
299
300 pub fn memory_info(&self) -> (usize, usize) {
302 (self.device.total_memory, self.device.available_memory)
303 }
304}
305
/// Backend-specific state attached to a GPU context.
///
/// Every handle is a plain `u64`. NOTE(review): the visible code only ever
/// stores `0` in them, so they appear to be placeholders -- confirm how real
/// handles are meant to be obtained.
#[derive(Debug)]
pub enum BackendContext {
    /// State for the CUDA backend.
    Cuda {
        /// Context handle.
        context_handle: u64,
        /// Stream handle.
        stream_handle: u64,
    },
    /// State for the OpenCL backend.
    OpenCl {
        /// Context handle.
        context_handle: u64,
        /// Queue handle.
        queue_handle: u64,
    },
    /// State for the ROCm backend.
    Rocm {
        /// Context handle.
        context_handle: u64,
    },
    /// State for the OneAPI backend.
    OneApi {
        /// Context handle.
        context_handle: u64,
    },
    /// State for the Metal backend.
    Metal {
        /// Device handle.
        device_handle: u64,
        /// Queue handle.
        queue_handle: u64,
    },
    /// No backend state: CPU fallback.
    CpuFallback,
}
343
344impl BackendContext {
345 pub fn new(backend: &GpuBackend) -> Result<Self> {
347 match backend {
348 GpuBackend::Cuda => Ok(BackendContext::Cuda {
349 context_handle: 0, stream_handle: 0,
351 }),
352 GpuBackend::OpenCl => Ok(BackendContext::OpenCl {
353 context_handle: 0, queue_handle: 0,
355 }),
356 GpuBackend::Rocm => Ok(BackendContext::Rocm {
357 context_handle: 0, }),
359 GpuBackend::OneApi => Ok(BackendContext::OneApi {
360 context_handle: 0, }),
362 GpuBackend::Metal => Ok(BackendContext::Metal {
363 device_handle: 0, queue_handle: 0,
365 }),
366 GpuBackend::CpuFallback => Ok(BackendContext::CpuFallback),
367 }
368 }
369
370 pub fn is_valid(&self) -> bool {
372 match self {
373 BackendContext::CpuFallback => true,
374 _ => true, }
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built device exposes its fields and derived metrics.
    #[test]
    fn test_gpu_device_creation() {
        let total_memory = 8_000_000_000;
        let available_memory = 6_000_000_000;

        let device = GpuDevice::new(
            0,
            String::from("Test GPU"),
            total_memory,
            available_memory,
            String::from("7.5"),
            2048,
            GpuBackend::Cuda,
            true,
        );

        assert_eq!(device.device_id, 0);
        assert_eq!(device.name, "Test GPU");
        // Three quarters of the memory is free, so one quarter is in use.
        assert_eq!(device.memory_utilization(), 25.0);
        assert!(device.is_suitable_for_double_precision());
    }

    /// The default configuration validates; an undersized pool does not.
    #[test]
    fn test_gpu_config_validation() {
        assert!(GpuConfig::default().validate().is_ok());

        // 1024 is far below the minimum pool size accepted by `validate`.
        let undersized_pool = GpuConfig::default().with_memory_pool_size(1024);
        assert!(undersized_pool.validate().is_err());
    }

    /// `Auto` is the default strategy and `Specific` carries its device ID.
    #[test]
    fn test_device_selection_strategies() {
        assert_eq!(DeviceSelection::default(), DeviceSelection::Auto);

        let specific = DeviceSelection::Specific(0);
        assert_eq!(specific, DeviceSelection::Specific(0));
    }

    /// Backend contexts can be created for a GPU backend and for the CPU.
    #[test]
    fn test_backend_context_creation() {
        let gpu = BackendContext::new(&GpuBackend::Cuda).expect("Operation failed");
        let cpu = BackendContext::new(&GpuBackend::CpuFallback).expect("Operation failed");

        assert!(gpu.is_valid());
        assert!(cpu.is_valid());
    }
}