tensor_compute/gpu_store/
mod.rs1use crate::gpu_internals::{GpuInfo, GpuInstance};
2use once_cell::sync::Lazy;
3use std::sync::RwLock;
4
5static DEVICES: Lazy<GpuStore> = Lazy::new(|| {
6 let (s, r) = std::sync::mpsc::channel();
7 std::thread::spawn(move || s.send(futures::executor::block_on(GpuStore::new())));
8 r.recv().unwrap()
9});
10
11pub struct GpuStore {
12 current: RwLock<usize>,
13 available_devices: Vec<GpuInstance>,
14}
15
16impl GpuStore {
17 pub fn get_default() -> &'static GpuInstance {
18 let current_idx = DEVICES.current.read().unwrap();
19 &DEVICES.available_devices[*current_idx]
20 }
21
22 pub fn get(gpu_info: &GpuInfo) -> &'static GpuInstance {
23 DEVICES
24 .available_devices
25 .iter()
26 .find(|dev| dev.info() == gpu_info)
27 .unwrap()
28 }
29
30 pub async fn select_gpu(gpu_info: &GpuInfo) {
31 let idx = DEVICES
32 .available_devices
33 .iter()
34 .position(|dev| dev.info() == gpu_info)
35 .unwrap();
36 *(&DEVICES.current).write().unwrap() = idx;
37 }
38
39 pub fn list_gpus() -> Vec<&'static GpuInfo> {
40 (&DEVICES)
41 .available_devices
42 .iter()
43 .map(|dev| dev.info())
44 .collect()
45 }
46
47 async fn new() -> Self {
48 let gpu_factory = crate::gpu_internals::gpu_factory::GpuFactory::new().await;
49
50 let gpu_list = gpu_factory.list_gpus().await;
51
52 let mut gpu_instances = vec![];
53 for gpu_info in &gpu_list {
54 gpu_instances.push(gpu_factory.request_gpu(&gpu_info).await);
55 }
56 assert!(!gpu_instances.is_empty(), "No GPU detected!");
57 Self {
58 current: RwLock::new(0),
59 available_devices: gpu_instances,
60 }
61 }
62}