tensor_compute/gpu_store/mod.rs

1use crate::gpu_internals::{GpuInfo, GpuInstance};
2use once_cell::sync::Lazy;
3use std::sync::RwLock;
4
5static DEVICES: Lazy<GpuStore> = Lazy::new(|| {
6    let (s, r) = std::sync::mpsc::channel();
7    std::thread::spawn(move || s.send(futures::executor::block_on(GpuStore::new())));
8    r.recv().unwrap()
9});
10
11pub struct GpuStore {
12    current: RwLock<usize>,
13    available_devices: Vec<GpuInstance>,
14}
15
16impl GpuStore {
17    pub fn get_default() -> &'static GpuInstance {
18        let current_idx = DEVICES.current.read().unwrap();
19        &DEVICES.available_devices[*current_idx]
20    }
21
22    pub fn get(gpu_info: &GpuInfo) -> &'static GpuInstance {
23        DEVICES
24            .available_devices
25            .iter()
26            .find(|dev| dev.info() == gpu_info)
27            .unwrap()
28    }
29
30    pub async fn select_gpu(gpu_info: &GpuInfo) {
31        let idx = DEVICES
32            .available_devices
33            .iter()
34            .position(|dev| dev.info() == gpu_info)
35            .unwrap();
36        *(&DEVICES.current).write().unwrap() = idx;
37    }
38
39    pub fn list_gpus() -> Vec<&'static GpuInfo> {
40        (&DEVICES)
41            .available_devices
42            .iter()
43            .map(|dev| dev.info())
44            .collect()
45    }
46
47    async fn new() -> Self {
48        let gpu_factory = crate::gpu_internals::gpu_factory::GpuFactory::new().await;
49
50        let gpu_list = gpu_factory.list_gpus().await;
51
52        let mut gpu_instances = vec![];
53        for gpu_info in &gpu_list {
54            gpu_instances.push(gpu_factory.request_gpu(&gpu_info).await);
55        }
56        assert!(!gpu_instances.is_empty(), "No GPU detected!");
57        Self {
58            current: RwLock::new(0),
59            available_devices: gpu_instances,
60        }
61    }
62}