Skip to main content

unlab_gpu/
backend.rs

1//
2// Copyright (c) 2025-2026 Ɓukasz Szpakowski
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at https://mozilla.org/MPL/2.0/.
7//
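//! Backend selection for matrices: loading the backend configuration and initializing or
//! finalizing the default backend.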
9use std::fs::File;
10use std::io::ErrorKind;
11use std::io::Read;
12use std::path::Path;
13#[cfg(any(feature = "opencl", feature = "cuda"))]
14use std::sync::Arc;
15#[cfg(feature = "opencl")]
16use crate::matrix;
17#[cfg(feature = "opencl")]
18use crate::matrix::opencl::CL_DEVICE_TYPE_ALL;
19#[cfg(feature = "opencl")]
20use crate::matrix::opencl::ClBackend;
21#[cfg(feature = "opencl")]
22use crate::matrix::opencl::Context;
23#[cfg(feature = "opencl")]
24use crate::matrix::opencl::Device;
25#[cfg(feature = "opencl")]
26use crate::matrix::opencl::get_platforms;
27#[cfg(feature = "cuda")]
28use crate::matrix::cuda::CudaBackend;
29#[cfg(any(feature = "opencl", feature = "cuda"))]
30use crate::matrix::set_default_backend;
31use crate::matrix::unset_default_backend;
32use crate::serde::Deserialize;
33use crate::serde::Serialize;
34use crate::toml;
35use crate::error::*;
36
37/// A backend enumeration.
38#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize)]
39pub enum Backend
40{
41    /// An OpenCL backend.
42    #[serde(rename = "OpenCL")]
43    OpenCl,
44    /// A CUDA backend.
45    #[serde(rename = "CUDA")]
46    Cuda,
47}
48
49/// A structure of backend configuration.
50#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
51pub struct BackendConfig
52{
53    /// A backend.
54    pub backend: Option<Backend>,
55    /// An ordinal number for the CUDA backend.
56    pub ordinal: Option<usize>,
57    /// A platform index for the OpenCL backend.
58    pub platform: Option<usize>,
59    /// A device index for the OpenCL backend.
60    pub device: Option<usize>,
61    /// If this field is `true`, the CUDA backend uses the cuBLAS library.
62    pub cublas: Option<bool>,
63    /// If this field is `true`, the CUDA backend uses the mma instruction.
64    pub mma: Option<bool>,
65}
66
67impl BackendConfig
68{
69    /// Reads a backend configuration from the reader.
70    pub fn read(r: &mut dyn Read) -> Result<Self>
71    {
72        let mut s = String::new();
73        match r.read_to_string(&mut s) {
74            Ok(_) => {
75                match toml::from_str(s.as_str()) {
76                    Ok(config) => Ok(config),
77                    Err(err) => Err(Error::TomlDe(err)),
78                }
79            },
80            Err(err) => Err(Error::Io(err)),
81        }
82    }
83
84    /// Loads a backend configuration from the file.
85    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self>
86    {
87        match File::open(path) {
88            Ok(mut file) => Self::read(&mut file),
89            Err(err) => Err(Error::Io(err)),
90        }
91    }
92
93    /// Loads a backend configuration from the file if the file exists, otherwise this method
94    /// returns `None`.
95    pub fn load_opt<P: AsRef<Path>>(path: P) -> Result<Option<Self>>
96    {
97        match File::open(path) {
98            Ok(mut file) => Ok(Some(Self::read(&mut file)?)),
99            Err(err) if err.kind() == ErrorKind::NotFound => Ok(None),
100            Err(err) => Err(Error::Io(err)),
101        }
102    }
103}
104
/// Creates an OpenCL backend on the selected platform and device, and installs it as the
/// default backend for matrices.
///
/// `platform_idx` indexes the platforms returned by `get_platforms`. `device_idx` indexes the
/// devices of any type (`CL_DEVICE_TYPE_ALL`) on that platform.
///
/// # Errors
///
/// Returns `Error::Matrix(matrix::Error::NoPlatform)` or `Error::Matrix(matrix::Error::NoDevice)`
/// when an index is out of range. OpenCL failures are wrapped as
/// `Error::Matrix(matrix::Error::OpenCl(_))`, and backend creation or installation failures as
/// `Error::Matrix(_)`.
#[cfg(feature = "opencl")]
fn initialize_opencl_backend(platform_idx: usize, device_idx: usize) -> Result<()>
{
    let platforms = get_platforms().map_err(|err| Error::Matrix(matrix::Error::OpenCl(err)))?;
    let platform = platforms
        .get(platform_idx)
        .ok_or_else(|| Error::Matrix(matrix::Error::NoPlatform))?;
    let device_ids = platform
        .get_devices(CL_DEVICE_TYPE_ALL)
        .map_err(|err| Error::Matrix(matrix::Error::OpenCl(err)))?;
    let device_id = device_ids
        .get(device_idx)
        .ok_or_else(|| Error::Matrix(matrix::Error::NoDevice))?;
    let device = Device::new(*device_id);
    let context = Context::from_device(&device)
        .map_err(|err| Error::Matrix(matrix::Error::OpenCl(err)))?;
    let backend = ClBackend::new_with_context(context).map_err(Error::Matrix)?;
    set_default_backend(Arc::new(backend)).map_err(Error::Matrix)
}
138
139#[cfg(not(feature = "opencl"))]
140fn initialize_opencl_backend(_platform_idx: usize, _device_idx: usize) -> Result<()>
141{ Err(Error::NoOpenClBackend) }
142
/// Creates a CUDA backend for the device with the given ordinal, and installs it as the
/// default backend for matrices.
///
/// `is_cublas` and `is_mma` are passed through as the flags of
/// `CudaBackend::new_with_ordinal_and_flags`.
///
/// # Errors
///
/// A failure to create or install the backend is wrapped in `Error::Matrix`.
#[cfg(feature = "cuda")]
fn initialize_cuda_backend(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<()>
{
    let backend = CudaBackend::new_with_ordinal_and_flags(ordinal, is_cublas, is_mma)
        .map_err(Error::Matrix)?;
    set_default_backend(Arc::new(backend)).map_err(Error::Matrix)
}
156
157#[cfg(not(feature = "cuda"))]
158fn initialize_cuda_backend(_ordinal: usize, _is_cublas: bool, _is_mma: bool) -> Result<()>
159{ Err(Error::NoCudaBackend) }
160
161/// Initializes a backend for matrices with the backend configuration.
162///
163/// If the backend configuration isn't passed, this method uses the default field values of
164/// backend configuration.
165pub fn initialize_backend_with_config(config: &Option<BackendConfig>) -> Result<()>
166{
167    #[cfg(feature = "cuda")]
168    let mut backend = Backend::Cuda;
169    #[cfg(not(feature = "cuda"))]
170    let mut backend = Backend::OpenCl;
171    let mut ordinal = 0usize;
172    let mut platform_idx = 0usize;
173    let mut device_idx = 0usize;
174    let mut is_cublas = true;
175    let mut is_mma = false;
176    match config {
177        Some(config) => {
178            backend = config.backend.unwrap_or(backend);
179            ordinal = config.ordinal.unwrap_or(ordinal);
180            platform_idx = config.platform.unwrap_or(platform_idx);
181            device_idx = config.device.unwrap_or(device_idx);
182            is_cublas = config.cublas.unwrap_or(is_cublas);
183            is_mma = config.mma.unwrap_or(is_mma);
184        },
185        None => (),
186    }
187    match backend {
188        Backend::OpenCl => initialize_opencl_backend(platform_idx, device_idx),
189        Backend::Cuda => initialize_cuda_backend(ordinal, is_cublas, is_mma),
190    }
191}
192
193/// Initializes a backend for matrices with the file of backend configuration.
194///
195/// If the file of backend configuration doesn't exist, this method uses the default field values
196/// of backend configuration.
197pub fn initialize_backend<P: AsRef<Path>>(path: P) -> Result<()>
198{
199    let config = BackendConfig::load_opt(path)?;
200    initialize_backend_with_config(&config)
201}
202
203/// Finalizes a backend for matrices.
204pub fn finalize_backend() -> Result<()>
205{
206    match unset_default_backend() {
207        Ok(()) => Ok(()),
208        Err(err) => Err(Error::Matrix(err)),        
209    }
210}
211
212#[cfg(test)]
213mod tests;