unmtx_gpu/
cuda.rs

1//
2// Copyright (c) 2025 Ɓukasz Szpakowski
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at https://mozilla.org/MPL/2.0/.
7//
8//! A module that contains a CUDA backend.
9use std::default::Default;
10use std::ffi::c_int;
11use std::ffi::c_void;
12use std::sync::Arc;
13use std::sync::Mutex;
14use crate::Backend;
15use crate::BackendArray;
16use crate::Error;
17use crate::Result;
18use crate::mutex_lock;
19
20pub use cudarc::cublas::result::CublasError;
21pub use cudarc::driver::DriverError;
22
23use cudarc::cublas::result::sgemm;
24use cudarc::cublas::sys::cublasOperation_t;
25use cudarc::cublas::CudaBlas;
26use cudarc::driver::sys::CUdeviceptr;
27use cudarc::driver::CudaDevice;
28use cudarc::driver::CudaFunction;
29use cudarc::driver::CudaSlice;
30use cudarc::driver::DeviceRepr;
31use cudarc::driver::DevicePtr;
32use cudarc::driver::DevicePtrMut;
33use cudarc::driver::LaunchAsync;
34use cudarc::driver::LaunchConfig;
35use cudarc::nvrtc::CompileError;
36use cudarc::nvrtc::CompileOptions;
37use cudarc::nvrtc::compile_ptx_with_opts;
38
39const SOURCE: &'static str = include_str!("cuda.cu");
40
// Names of all kernels defined in cuda.cu; they are registered with the driver
// by `load_ptx` and looked up by name in `get_func`. (`'static` is implied for
// consts, so the explicit lifetimes were redundant.)
const KERNELS: &[&str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "swish_a",
    "swish_at",
    "softmax_a",
    "softmax_at",
    "sqrt_a",
    "sqrt_at",
    "repeat_col_a",
    "repeat_row_a",
    "abs_a",
    "abs_at",
    "pow_a_b",
    "pow_at_b",
    "pow_a_bt",
    "pow_at_bt",
    "pow_a_b_for_scalar",
    "pow_at_b_for_scalar",
    "rpow_a_b_for_scalar",
    "rpow_at_b_for_scalar",
    "exp_a",
    "exp_at",
    "ln_a",
    "ln_at",
    "log2_a",
    "log2_at",
    "log10_a",
    "log10_at",
    "sin_a",
    "sin_at",
    "cos_a",
    "cos_at",
    "tan_a",
    "tan_at",
    "asin_a",
    "asin_at",
    "acos_a",
    "acos_at",
    "atan_a",
    "atan_at",
    "atan2_a_b",
    "atan2_at_b",
    "atan2_a_bt",
    "atan2_at_bt",
    "atan2_a_b_for_scalar",
    "atan2_at_b_for_scalar",
    "ratan2_a_b_for_scalar",
    "ratan2_at_b_for_scalar",
    "sinh_a",
    "sinh_at",
    "cosh_a",
    "cosh_at",
    "asinh_a",
    "asinh_at",
    "acosh_a",
    "acosh_at",
    "atanh_a",
    "atanh_at",
    "signum_a",
    "signum_at",
    "ceil_a",
    "ceil_at",
    "floor_a",
    "floor_at",
    "round_a",
    "round_at",
    "trunc_a",
    "trunc_at",
    "max_a_b",
    "max_at_b",
    "max_a_bt",
    "max_at_bt",
    "max_a_b_for_scalar",
    "max_at_b_for_scalar",
    "min_a_b",
    "min_at_b",
    "min_a_bt",
    "min_at_bt",
    "min_a_b_for_scalar",
    "min_at_b_for_scalar"
];
158
/// A structure of CUDA backend array.
///
/// This structure contains the reference to the device memory.
#[derive(Debug)]
pub struct CudaBackendArray
{
    // Device buffer, shared behind a mutex so arrays can alias the same memory.
    slice: Arc<Mutex<CudaSlice<f32>>>,
    // Number of f32 elements in the buffer.
    len: usize,
}
168
// Internal backend state guarded by `CudaBackend::inner`.
struct CudaInnerBackend
{
    // Handle to the CUDA device (kernel lookup, memory operations, synchronization).
    device: Arc<CudaDevice>,
    // cuBLAS handle; `Some` only when the backend was created with `is_cublas`.
    cublas: Option<CudaBlas>,
}
174
/// A structure of CUDA backend.
pub struct CudaBackend
{
    // Mutable backend state (device + optional cuBLAS handle) behind a mutex.
    inner: Mutex<CudaInnerBackend>,
    // Whether matrix multiplication is delegated to cuBLAS.
    has_cublas: bool,
    // Whether the kernels were compiled with the mma code path (sm_80 option).
    has_mma: bool,
}
182
183fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
184{
185    if m == 1 && !is_mul {
186        let n2 = ((n + 1023) / 1024) as u32;
187        LaunchConfig {
188            grid_dim: (n2, 1, 1),
189            block_dim: (1024, 1, 1),
190            shared_mem_bytes: 0,
191        }
192    } else if n == 1 && !is_mul {
193        let m2 = ((m + 1023) / 1024) as u32;
194        LaunchConfig {
195            grid_dim: (1, m2, 1),
196            block_dim: (1, 1024, 1),
197            shared_mem_bytes: 0,
198        }
199    } else if is_mul {
200        if is_mma {
201            let n2 = ((n + 63) / 64) as u32;
202            let m2 = ((m + 63) / 64) as u32;
203            LaunchConfig {
204                grid_dim: (n2, m2, 1),
205                block_dim: (1024, 1, 1),
206                shared_mem_bytes: 0,
207            }
208        } else {
209            let n2 = (((n + 3) / 4 + 15) / 16) as u32;
210            let m2 = (((m + 3) / 4 + 15) / 16) as u32;
211            LaunchConfig {
212                grid_dim: (n2, m2, 1),
213                block_dim: (16, 16, 1),
214                shared_mem_bytes: 0,
215            }
216        }
217    } else {
218        let n2 = ((n + 31) / 32) as u32;
219        let m2 = ((m + 31) / 32) as u32;
220        LaunchConfig {
221            grid_dim: (n2, m2, 1),
222            block_dim: (32, 32, 1),
223            shared_mem_bytes: 0,
224        }
225    }
226}
227
228impl CudaBackend
229{
230    /// Creates a CUDA backend for a first device.
231    pub fn new() -> Result<CudaBackend>
232    {
233        if cfg!(feature = "default_cublas") {
234            Self::new_with_ordinal_and_flags(0, true, false)
235        } else if cfg!(feature = "default_mma") {
236            Self::new_with_ordinal_and_flags(0, false, true)
237        } else {
238            Self::new_with_ordinal_and_flags(0, false, false)
239        }
240    }
241    
    /// Creates a CUDA backend with the ordinal number and the flags.
    ///
    /// This method takes the following flags:
    ///
    /// - `is_cublas` - use the cuBLAS library to multiplication of matrices
    /// - `is_mma` - use the mma instruction to multiplication of matrices
    pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
    {
        let device = match CudaDevice::new(ordinal) {
            Ok(tmp_device) => tmp_device,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let mut options: CompileOptions = Default::default();
        if is_mma {
            // The mma code path is enabled via a preprocessor define and targets
            // compute capability 8.0 (sm_80).
            options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
            options.arch = Some("sm_80");
        }
        // Compile the CUDA source to PTX at runtime with NVRTC.
        let ptx = match compile_ptx_with_opts(SOURCE, options) {
            Ok(tmp_ptx) => tmp_ptx,
            // A compilation failure carries the NVRTC log; surface it as the message.
            Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
            Err(err) => return Err(Error::Compilation(format!("{}", err))),
        };
        // Register all kernels under the "unmtx_gpu" module name for later lookup.
        match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
            Ok(()) => (),
            Err(err) => return Err(Error::Cuda(err)),
        }
        let cublas = if is_cublas {
            match CudaBlas::new(device.clone()) {
                Ok(tmp_cublas) => Some(tmp_cublas),
                Err(err) => return Err(Error::Cublas(err)),
            }
        } else {
            None
        };
        Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
    }
278    
    /// Returns `true` if the backend uses the cuBLAS library for multiplication of matrices.
    pub fn has_cublas(&self) -> bool
    { self.has_cublas }
281    
    // Validates two backend arrays with `f`, looks up the named kernel, runs `g`
    // with raw kernel parameters for `a` (input) and `b` (output), and then
    // synchronizes the device.
    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                // Lock each distinct slice exactly once: when `a` and `b` alias the
                // same device buffer, locking its mutex twice would deadlock.
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                // Wait for the asynchronous kernel launch to finish.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
312
    // Validates three backend arrays with `f`, looks up the named kernel, runs `g`
    // with raw kernel parameters for `a` and `b` (inputs) and `c` (output), and
    // then synchronizes the device.
    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                // Lock each distinct slice exactly once: locking an aliased slice's
                // mutex twice would deadlock. The tuple records which pairs share the
                // same underlying device buffer (a==b, a==c, b==c).
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    // All three buffers are distinct.
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    // `a` and `b` alias; pass the same buffer for both inputs.
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    // `a` and `c` alias; the output is written over `a`.
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    // `b` and `c` alias; the output is written over `b`.
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    // Remaining reachable case: all three alias one buffer.
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                // Wait for the asynchronous kernel launch to finish.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }    
362
    // Validates three backend arrays with `f`, then runs `g` with raw device
    // pointers for `a` and `b` (inputs) and `c` (output) — cuBLAS takes device
    // pointers rather than kernel parameters — and synchronizes the device.
    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                // Lock each distinct slice exactly once: locking an aliased slice's
                // mutex twice would deadlock. The tuple records which pairs share the
                // same underlying device buffer (a==b, a==c, b==c).
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    // All three buffers are distinct.
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    // `a` and `b` alias; pass the same pointer for both inputs.
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    // `a` and `c` alias; the output is written over `a`.
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    // `b` and `c` alias; the output is written over `b`.
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    // Remaining reachable case: all three alias one buffer.
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                // Wait for the asynchronous cuBLAS call to finish.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
418    
419    fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
420    {
421        let is_mma = self.has_mma;
422        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
423                if a2.len != n * m {
424                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
425                }
426                if b2.len != n * m {
427                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
428                }
429                Ok(())
430        }, |_, kernel, a_param, b_param| {
431                let config = preferred_launch_config(n, m, false, is_mma);
432                let mut params = vec![
433                    a_param,
434                    b_param,
435                    n.as_kernel_param(),
436                    m.as_kernel_param()
437                ];
438                unsafe {
439                    match kernel.launch(config, &mut params) {
440                        Ok(()) => Ok(()),
441                        Err(err) => Err(Error::Cuda(err)),
442                    }
443                }
444        })
445    }
446
447    fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
448    {
449        let is_mma = self.has_mma;
450        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
451                if a2.len != n * m {
452                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
453                }
454                if b2.len != n * m {
455                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
456                }
457                if c2.len != n * m {
458                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
459                }
460                Ok(())
461        }, |_, kernel, a_param, b_param, c_param| {
462                let config = preferred_launch_config(n, m, false, is_mma);
463                let mut params = vec![
464                    a_param,
465                    b_param,
466                    c_param,
467                    n.as_kernel_param(),
468                    m.as_kernel_param()
469                ];
470                unsafe {
471                    match kernel.launch(config, &mut params) {
472                        Ok(()) => Ok(()),
473                        Err(err) => Err(Error::Cuda(err)),
474                    }
475                }
476        })
477    }
478
479    fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
480    {
481        let is_mma = self.has_mma;
482        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
483                if a2.len != n * l {
484                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
485                }
486                if b2.len != l * m {
487                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
488                }
489                if c2.len != n * m {
490                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
491                }
492                Ok(())
493        }, |_, kernel, a_param, b_param, c_param| {
494                let config = preferred_launch_config(n, m, true, is_mma);
495                let mut params = vec![
496                    a_param,
497                    b_param,
498                    c_param,
499                    n.as_kernel_param(),
500                    m.as_kernel_param(),
501                    l.as_kernel_param()
502                ];
503                unsafe {
504                    match kernel.launch(config, &mut params) {
505                        Ok(()) => Ok(()),
506                        Err(err) => Err(Error::Cuda(err)),
507                    }
508                }
509        })
510    }
511
512    fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
513    {
514        let is_mma = self.has_mma;
515        self.check_and_launch2(kernel_name, a, c, |a2, c2| {
516                if a2.len != n * m  {
517                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
518                }
519                if c2.len != n * m {
520                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
521                }
522                Ok(())
523        }, |_, kernel, a_param, c_param| {
524                let config = preferred_launch_config(n, m, false, is_mma);
525                let mut params = vec![
526                    a_param,
527                    b.as_kernel_param(),
528                    c_param,
529                    n.as_kernel_param(),
530                    m.as_kernel_param()
531                ];
532                unsafe {
533                    match kernel.launch(config, &mut params) {
534                        Ok(()) => Ok(()),
535                        Err(err) => Err(Error::Cuda(err)),
536                    }
537                }
538        })
539    }
540
541    fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
542    {
543        let is_mma = self.has_mma;
544        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
545                if a2.len != n * m {
546                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
547                }
548                if b2.len != n * m {
549                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
550                }
551                Ok(())
552        }, |_, kernel, a_param, b_param| {
553                let config = preferred_launch_config(n, m, false, is_mma);
554                let mut params = vec![
555                    a_param,
556                    b_param,
557                    n.as_kernel_param(),
558                    m.as_kernel_param(),
559                    ((config.block_dim.1) as usize).as_kernel_param(),
560                    ((config.block_dim.0) as usize).as_kernel_param()
561                ];
562                unsafe {
563                    match kernel.launch(config, &mut params) {
564                        Ok(()) => Ok(()),
565                        Err(err) => Err(Error::Cuda(err)),
566                    }
567                }
568        })
569    }
570
571    fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
572    {
573        let is_mma = self.has_mma;
574        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
575                if a2.len != n {
576                    return Err(Error::BackendArrayElemCount(a2.len, n));
577                }
578                if b2.len != n * m {
579                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
580                }
581                Ok(())
582        }, |_, kernel, a_param, b_param| {
583                let config = preferred_launch_config(n, m, false, is_mma);
584                let mut params = vec![
585                    a_param,
586                    b_param,
587                    n.as_kernel_param(),
588                    m.as_kernel_param()
589                ];
590                unsafe {
591                    match kernel.launch(config, &mut params) {
592                        Ok(()) => Ok(()),
593                        Err(err) => Err(Error::Cuda(err)),
594                    }
595                }
596        })
597    }
598
599    fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
600    {
601        let is_mma = self.has_mma;
602        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
603                if a2.len != m {
604                    return Err(Error::BackendArrayElemCount(a2.len, m));
605                }
606                if b2.len != n * m {
607                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
608                }
609                Ok(())
610        }, |_, kernel, a_param, b_param| {
611                let config = preferred_launch_config(n, m, false, is_mma);
612                let mut params = vec![
613                    a_param,
614                    b_param,
615                    n.as_kernel_param(),
616                    m.as_kernel_param()
617                ];
618                unsafe {
619                    match kernel.launch(config, &mut params) {
620                        Ok(()) => Ok(()),
621                        Err(err) => Err(Error::Cuda(err)),
622                    }
623                }
624        })
625    }    
626    
    // Computes c = a * b with cuBLAS sgemm, where the buffers hold row-major
    // matrices: `a` is n x l, `b` is l x m and `c` is n x m (the `is_trans_*`
    // flags indicate that the corresponding buffer stores the operand transposed).
    //
    // cuBLAS operates on column-major matrices, so the call computes C^T = B^T * A^T
    // by swapping the operands and the m/n dimensions: a row-major matrix is the
    // transpose of the same memory interpreted as column-major.
    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
                // Check the element counts against the expected matrix shapes.
                if a2.len != n * l {
                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
                }
                if b2.len != l * m {
                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
                unsafe {
                    match &inner.cublas {
                        Some(cublas) => {
                            // Leading dimension of `a` as stored: n when the buffer
                            // holds the transposed matrix, l otherwise.
                            let (transa, lda) = if is_trans_a {
                                (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                            };
                            // Leading dimension of `b` as stored: l when transposed,
                            // m otherwise.
                            let (transb, ldb) = if is_trans_b {
                                (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                            };
                            // Plain product: c = 1.0 * a * b + 0.0 * c.
                            let alpha = 1.0f32;
                            let beta = 0.0f32;
                            // Operands and the m/n dimensions are swapped for the
                            // row-major-as-column-major trick; ldc is m.
                            let res = sgemm(*cublas.handle(),
                                transb, transa,
                                m as c_int, n as c_int, l as c_int,
                                (&alpha) as *const _,
                                b_device_ptr as *const _, ldb,
                                a_device_ptr as *const _, lda,
                                (&beta) as *const _,
                                c_device_ptr as *mut _, m as c_int);
                            match res {
                                Ok(()) => Ok(()),
                                Err(err) => Err(Error::Cublas(err)),
                            }
                        },
                        None => Err(Error::NoCublas),
                    }
                }
        })
    }
674}
675
676impl Backend for CudaBackend
677{
678    fn name(&self) -> &'static str
679    {
680        if self.has_cublas {
681            "CUDA(cuBLAS)"
682        } else if self.has_mma {
683            "CUDA(mma)"
684        } else {
685            "CUDA"
686        }
687    }
688    
    // Mirrors the inherent `has_cublas` method for users of the `Backend` trait.
    fn has_cublas(&self) -> bool
    { self.has_cublas }
691
    // Allocates a device buffer of `n` f32 elements without initializing it;
    // `device.alloc` is unsafe because the memory contents are undefined until
    // the caller stores data into the array.
    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }
702
703    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
704    {
705        let inner_g = mutex_lock(&self.inner)?;
706        let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
707            Ok(tmp_slice) => tmp_slice,
708            Err(err) => return Err(Error::Cuda(err)),
709        };
710        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
711        Ok(BackendArray::Cuda(cuda_array))
712    }
713    
714    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
715    {
716        let inner_g = mutex_lock(&self.inner)?;
717        let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
718            Ok(tmp_slice) => tmp_slice,
719            Err(err) => return Err(Error::Cuda(err)),
720        };
721        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
722        Ok(BackendArray::Cuda(cuda_array))
723    }
724    
    // Copies the device array `a` into the host slice `elems` (synchronous
    // device-to-host copy).
    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                // The destination must match the array's element count exactly.
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
744
    // Copies the host slice `elems` into the device array `a` (synchronous
    // host-to-device copy).
    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                // The source must match the array's element count exactly.
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
764    
    // Copies device array `a` into device array `b` (device-to-device copy).
    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                // Copying an array onto itself is a no-op — and locking the same
                // slice mutex twice below would deadlock.
                if Arc::ptr_eq(&a2.slice, &b2.slice) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                let mut b_slice_g = mutex_lock(&b2.slice)?;
                match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                // Wait for the asynchronous copy to finish.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
792
    // --- Element-wise kernel wrappers ------------------------------------
    // Each wrapper below forwards to a check_and_launch_* helper together with
    // the name of the CUDA kernel to launch (the names match the KERNELS table
    // at the top of this file).  An "at"/"bt" suffix means that operand is
    // taken in transposed form; n and m are the matrix dimensions forwarded to
    // the launch helper — presumably rows and columns, confirm against the
    // kernels in cuda.cu.

    /// Launches the "transpose_a" kernel (presumably writes the transpose of `a` into `b`).
    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    // Addition wrappers (presumably c = a + b) for each transposition combination.
    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    // Subtraction wrappers (presumably c = a - b) for each transposition combination.
    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }
819    
820    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
821    {
822        if self.has_cublas {
823            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
824        } else {
825            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
826        }
827    }
828
829    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
830    {
831        if self.has_cublas {
832            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
833        } else {
834            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
835        }
836    }
837
838    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
839    {
840        if self.has_cublas {
841            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
842        } else {
843            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l) 
844        }
845    }
846
847    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
848    {
849        if self.has_cublas {
850            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
851        } else {
852            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
853        }
854    }
855
    // Per-element multiplication and division wrappers — the "_for_elems"
    // suffix distinguishes these from the matrix products mul_a_b & co.
    // (presumably c[i] = a[i] * b[i] resp. a[i] / b[i]; confirm in cuda.cu).
    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }
879
    // Matrix-scalar wrappers: here `b` is an f32 scalar applied to every
    // element of `a`.  The "r"-prefixed variants presumably reverse the
    // operand order (e.g. rsub: b - a, rdiv: b / a) — confirm in cuda.cu.
    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }
915
    // Unary function wrappers: apply a function to every element of `a`,
    // writing the result to `b`.
    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }

    fn swish_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_a", a, b, n, m) }

    fn swish_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_at", a, b, n, m) }

    // Softmax goes through the tile-aware launch helper rather than the plain
    // one — presumably because it needs a reduction per row/column.
    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn sqrt_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sqrt_a", a, b, n, m) }

    fn sqrt_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sqrt_at", a, b, n, m) }

    // Column/row replication wrappers with their dedicated launch helpers.
    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }

    fn abs_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("abs_a", a, b, n, m) }

    fn abs_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("abs_at", a, b, n, m) }
957
    // Power wrappers (presumably element-wise c = a ^ b) and their scalar
    // variants; rpow presumably computes b ^ a — confirm in cuda.cu.
    fn pow_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_a_b", a, b, c, n, m) }

    fn pow_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_at_b", a, b, c, n, m) }

    fn pow_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_a_bt", a, b, c, n, m) }

    fn pow_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_at_bt", a, b, c, n, m) }

    fn pow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("pow_a_b_for_scalar", a, b, c, n, m) }

    fn pow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("pow_at_b_for_scalar", a, b, c, n, m) }

    fn rpow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rpow_a_b_for_scalar", a, b, c, n, m) }

    fn rpow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rpow_at_b_for_scalar", a, b, c, n, m) }

    // Exponential and logarithm wrappers (natural, base-2 and base-10 logs).
    fn exp_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("exp_a", a, b, n, m) }

    fn exp_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("exp_at", a, b, n, m) }

    fn ln_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ln_a", a, b, n, m) }

    fn ln_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ln_at", a, b, n, m) }

    fn log2_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log2_a", a, b, n, m) }

    fn log2_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log2_at", a, b, n, m) }

    fn log10_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log10_a", a, b, n, m) }

    fn log10_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log10_at", a, b, n, m) }
1005
    // Trigonometric and inverse-trigonometric wrappers.
    fn sin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sin_a", a, b, n, m) }

    fn sin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sin_at", a, b, n, m) }

    fn cos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cos_a", a, b, n, m) }

    fn cos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cos_at", a, b, n, m) }

    fn tan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tan_a", a, b, n, m) }

    fn tan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tan_at", a, b, n, m) }

    fn asin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asin_a", a, b, n, m) }

    fn asin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asin_at", a, b, n, m) }

    fn acos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acos_a", a, b, n, m) }

    fn acos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acos_at", a, b, n, m) }

    fn atan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atan_a", a, b, n, m) }

    fn atan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atan_at", a, b, n, m) }

    // Two-argument arctangent wrappers (element-wise and scalar variants;
    // ratan2 presumably swaps the operand order — confirm in cuda.cu).
    fn atan2_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_a_b", a, b, c, n, m) }

    fn atan2_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_at_b", a, b, c, n, m) }

    fn atan2_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_a_bt", a, b, c, n, m) }

    fn atan2_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_at_bt", a, b, c, n, m) }

    fn atan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("atan2_a_b_for_scalar", a, b, c, n, m) }

    fn atan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("atan2_at_b_for_scalar", a, b, c, n, m) }

    fn ratan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("ratan2_a_b_for_scalar", a, b, c, n, m) }

    fn ratan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("ratan2_at_b_for_scalar", a, b, c, n, m) }
1065
    // Hyperbolic and inverse-hyperbolic wrappers.
    fn sinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sinh_a", a, b, n, m) }

    fn sinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sinh_at", a, b, n, m) }

    fn cosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cosh_a", a, b, n, m) }

    fn cosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cosh_at", a, b, n, m) }

    fn asinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asinh_a", a, b, n, m) }

    fn asinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asinh_at", a, b, n, m) }

    fn acosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acosh_a", a, b, n, m) }

    fn acosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acosh_at", a, b, n, m) }

    fn atanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atanh_a", a, b, n, m) }

    fn atanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atanh_at", a, b, n, m) }
1095
    // Sign and rounding wrappers.
    fn signum_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("signum_a", a, b, n, m) }

    fn signum_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("signum_at", a, b, n, m) }

    fn ceil_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ceil_a", a, b, n, m) }

    fn ceil_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ceil_at", a, b, n, m) }

    fn floor_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("floor_a", a, b, n, m) }

    fn floor_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("floor_at", a, b, n, m) }

    fn round_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("round_a", a, b, n, m) }

    fn round_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("round_at", a, b, n, m) }

    fn trunc_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("trunc_a", a, b, n, m) }

    fn trunc_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("trunc_at", a, b, n, m) }
1125
    // Element-wise maximum/minimum wrappers and their scalar variants.
    fn max_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("max_a_b", a, b, c, n, m) }

    fn max_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("max_at_b", a, b, c, n, m) }

    fn max_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("max_a_bt", a, b, c, n, m) }

    fn max_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("max_at_bt", a, b, c, n, m) }

    fn max_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("max_a_b_for_scalar", a, b, c, n, m) }

    fn max_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("max_at_b_for_scalar", a, b, c, n, m) }

    fn min_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("min_a_b", a, b, c, n, m) }

    fn min_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("min_at_b", a, b, c, n, m) }

    fn min_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("min_a_bt", a, b, c, n, m) }

    fn min_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("min_at_bt", a, b, c, n, m) }

    fn min_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("min_a_b_for_scalar", a, b, c, n, m) }

    fn min_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("min_at_b_for_scalar", a, b, c, n, m) }
1161}
1162
// Unit tests for this backend; the module body lives in the sibling file
// (cuda/tests.rs) and is compiled only in test builds.
#[cfg(test)]
mod tests;