unmtx_gpu/cuda.rs

//
// Copyright (c) 2025 Ɓukasz Szpakowski
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
//! A module that contains a CUDA backend.
use std::default::Default;
use std::ffi::c_int;
use std::ffi::c_void;
use std::sync::Arc;
use std::sync::Mutex;
use crate::Backend;
use crate::BackendArray;
use crate::Error;
use crate::Result;
use crate::mutex_lock;

pub use cudarc::cublas::result::CublasError;
pub use cudarc::driver::DriverError;

use cudarc::cublas::result::sgemm;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::CudaBlas;
use cudarc::driver::sys::CUdeviceptr;
use cudarc::driver::CudaDevice;
use cudarc::driver::CudaFunction;
use cudarc::driver::CudaSlice;
use cudarc::driver::DeviceRepr;
use cudarc::driver::DevicePtr;
use cudarc::driver::DevicePtrMut;
use cudarc::driver::LaunchAsync;
use cudarc::driver::LaunchConfig;
use cudarc::nvrtc::CompileError;
use cudarc::nvrtc::CompileOptions;
use cudarc::nvrtc::compile_ptx_with_opts;

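// The CUDA C source of the kernels, embedded at compile time and compiled to PTX at
// runtime when the backend is created.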
const SOURCE: &'static str = include_str!("cuda.cu");

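// The names of the kernels that are loaded from the compiled PTX module; each name
// must match a __global__ function defined in cuda.cu.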
const KERNELS: &'static [&'static str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "swish_a",
    "swish_at",
    "softmax_a",
    "softmax_at",
    "repeat_col_a",
    "repeat_row_a"
];

/// A structure of a CUDA backend array.
///
/// This structure contains a reference to the device memory.
#[derive(Debug)]
pub struct CudaBackendArray
{
    slice: Arc<Mutex<CudaSlice<f32>>>,
    len: usize,
}

struct CudaInnerBackend
{
    device: Arc<CudaDevice>,
    cublas: Option<CudaBlas>,
}

/// A structure of a CUDA backend.
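///
/// # Examples
///
/// A minimal usage sketch. The import paths are assumed from this crate's layout, the
/// matrices are assumed to be stored in row-major order, and the example requires a
/// CUDA-capable device:
///
/// ```no_run
/// use unmtx_gpu::Backend;
/// use unmtx_gpu::cuda::CudaBackend;
///
/// let backend = CudaBackend::new().unwrap();
/// // Store two 2x2 matrices on the device and multiply them.
/// let a = backend.alloc_and_store(&[1.0, 2.0, 3.0, 4.0]).unwrap();
/// let b = backend.alloc_and_store(&[5.0, 6.0, 7.0, 8.0]).unwrap();
/// let c = backend.alloc_and_store_zeros(4).unwrap();
/// backend.mul_a_b(&a, &b, &c, 2, 2, 2).unwrap();
/// // Copy the result back to the host.
/// let mut elems = vec![0.0f32; 4];
/// backend.load(&c, &mut elems).unwrap();
/// assert_eq!(elems, [19.0, 22.0, 43.0, 50.0]);
/// ```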
pub struct CudaBackend
{
    inner: Mutex<CudaInnerBackend>,
    has_cublas: bool,
    has_mma: bool,
}

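// Chooses a launch configuration for an n x m output matrix: a single column or a
// single row (for a non-multiplication) uses a one-dimensional block of 1024 threads,
// a matrix multiplication uses a 64x64 output tile per 1024-thread block with mma or
// a 16x16-thread block computing 4x4 elements per thread otherwise, and any other
// operation uses a plain 32x32 block.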
fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
{
    if m == 1 && !is_mul {
        let n2 = ((n + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (n2, 1, 1),
            block_dim: (1024, 1, 1),
            shared_mem_bytes: 0,
        }
    } else if n == 1 && !is_mul {
        let m2 = ((m + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (1, m2, 1),
            block_dim: (1, 1024, 1),
            shared_mem_bytes: 0,
        }
    } else if is_mul {
        if is_mma {
            let n2 = ((n + 63) / 64) as u32;
            let m2 = ((m + 63) / 64) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (1024, 1, 1),
                shared_mem_bytes: 0,
            }
        } else {
            let n2 = (((n + 3) / 4 + 15) / 16) as u32;
            let m2 = (((m + 3) / 4 + 15) / 16) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (16, 16, 1),
                shared_mem_bytes: 0,
            }
        }
    } else {
        let n2 = ((n + 31) / 32) as u32;
        let m2 = ((m + 31) / 32) as u32;
        LaunchConfig {
            grid_dim: (n2, m2, 1),
            block_dim: (32, 32, 1),
            shared_mem_bytes: 0,
        }
    }
}

impl CudaBackend
{
    /// Creates a CUDA backend for the first device.
    pub fn new() -> Result<CudaBackend>
    {
        if cfg!(feature = "default_cublas") {
            Self::new_with_ordinal_and_flags(0, true, false)
        } else if cfg!(feature = "default_mma") {
            Self::new_with_ordinal_and_flags(0, false, true)
        } else {
            Self::new_with_ordinal_and_flags(0, false, false)
        }
    }

    /// Creates a CUDA backend for the device with the given ordinal number, using the given flags.
    ///
    /// This method takes the following flags:
    ///
    /// - `is_cublas` - use the cuBLAS library for matrix multiplication
    /// - `is_mma` - use the `mma` instruction for matrix multiplication
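    ///
    /// # Examples
    ///
    /// A hedged sketch (the import path is assumed; device `0` and cuBLAS must be
    /// available):
    ///
    /// ```no_run
    /// use unmtx_gpu::cuda::CudaBackend;
    ///
    /// // The first device, with matrix multiplication delegated to cuBLAS.
    /// let backend = CudaBackend::new_with_ordinal_and_flags(0, true, false).unwrap();
    /// assert!(backend.has_cublas());
    /// ```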
    pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
    {
        let device = match CudaDevice::new(ordinal) {
            Ok(tmp_device) => tmp_device,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let mut options: CompileOptions = Default::default();
        if is_mma {
            options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
            options.arch = Some("sm_80");
        }
        let ptx = match compile_ptx_with_opts(SOURCE, options) {
            Ok(tmp_ptx) => tmp_ptx,
            Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
            Err(err) => return Err(Error::Compilation(format!("{}", err))),
        };
        match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
            Ok(()) => (),
            Err(err) => return Err(Error::Cuda(err)),
        }
        let cublas = if is_cublas {
            match CudaBlas::new(device.clone()) {
                Ok(tmp_cublas) => Some(tmp_cublas),
                Err(err) => return Err(Error::Cublas(err)),
            }
        } else {
            None
        };
        Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
    }

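    /// Returns `true` if this backend uses the cuBLAS library for matrix
    /// multiplication, otherwise `false`.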
    pub fn has_cublas(&self) -> bool
    { self.has_cublas }

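    // Checks the two backend arrays, locks their slices and passes them as kernel
    // parameters to the closure that launches the kernel. When both arguments refer to
    // the same slice, the mutex is locked only once so that the launch doesn't
    // deadlock. The device is synchronized after the launch.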
    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

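    // As check_and_launch2, but for kernels that take three backend arrays; every
    // aliasing combination of the three slices is handled so that no mutex is locked
    // twice.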
    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

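    // As check_and_launch3, but passes raw device pointers to the closure instead of
    // kernel parameters, so that the closure can call cuBLAS directly.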
    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
                if a2.len != n * l {
                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
                }
                if b2.len != l * m {
                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
                let config = preferred_launch_config(n, m, true, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param(),
                    l.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, c, |a2, c2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, c_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b.as_kernel_param(),
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param(),
                    ((config.block_dim.1) as usize).as_kernel_param(),
                    ((config.block_dim.0) as usize).as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n {
                    return Err(Error::BackendArrayElemCount(a2.len, n));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != m {
                    return Err(Error::BackendArrayElemCount(a2.len, m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

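    // Computes c = op(a) * op(b) (n x m = n x l * l x m) with cuBLAS. The matrices are
    // stored in row-major order while cuBLAS expects column-major order, so the sgemm
    // call swaps the operands and their dimensions: the column-major c^T = op(b)^T *
    // op(a)^T is exactly the row-major c.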
    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
                if a2.len != n * l {
                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
                }
                if b2.len != l * m {
                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
                unsafe {
                    match &inner.cublas {
                        Some(cublas) => {
                            let (transa, lda) = if is_trans_a {
                                (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                            };
                            let (transb, ldb) = if is_trans_b {
                                (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                            };
                            let alpha = 1.0f32;
                            let beta = 0.0f32;
                            let res = sgemm(*cublas.handle(),
                                transb, transa,
                                m as c_int, n as c_int, l as c_int,
                                (&alpha) as *const _,
                                b_device_ptr as *const _, ldb,
                                a_device_ptr as *const _, lda,
                                (&beta) as *const _,
                                c_device_ptr as *mut _, m as c_int);
                            match res {
                                Ok(()) => Ok(()),
                                Err(err) => Err(Error::Cublas(err)),
                            }
                        },
                        None => Err(Error::NoCublas),
                    }
                }
        })
    }
}

impl Backend for CudaBackend
{
    fn name(&self) -> &'static str
    {
        if self.has_cublas {
            "CUDA(cuBLAS)"
        } else if self.has_mma {
            "CUDA(mma)"
        } else {
            "CUDA"
        }
    }

    fn has_cublas(&self) -> bool
    { self.has_cublas }

    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                if Arc::ptr_eq(&a2.slice, &b2.slice) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                let mut b_slice_g = mutex_lock(&b2.slice)?;
                match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }

    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
        } else {
            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
        }
    }

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
        } else {
            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
        }
    }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
        } else {
            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l)
        }
    }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
        } else {
            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
        }
    }

    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }

    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }

    fn swish_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_a", a, b, n, m) }

    fn swish_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_at", a, b, n, m) }

    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }
}

#[cfg(test)]
mod tests;