unmtx_gpu/cuda.rs

//
// Copyright (c) 2025 Ɓukasz Szpakowski
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
//! A module that contains a CUDA backend.
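//!
//! A minimal usage sketch (not executed here: it needs a CUDA device at run
//! time, and the `unmtx_gpu::cuda` module path is assumed from this file's
//! location in the crate):
//!
//! ```no_run
//! use unmtx_gpu::cuda::CudaBackend;
//!
//! // A backend for the first CUDA device with the default flags.
//! let backend = CudaBackend::new().expect("CUDA backend");
//! ```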
use std::default::Default;
use std::ffi::c_int;
use std::ffi::c_void;
use std::sync::Arc;
use std::sync::Mutex;
use crate::Backend;
use crate::BackendArray;
use crate::Error;
use crate::Result;
use crate::mutex_lock;

pub use cudarc::cublas::result::CublasError;
pub use cudarc::driver::DriverError;

use cudarc::cublas::result::sgemm;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::CudaBlas;
use cudarc::driver::sys::CUdeviceptr;
use cudarc::driver::CudaDevice;
use cudarc::driver::CudaFunction;
use cudarc::driver::CudaSlice;
use cudarc::driver::DeviceRepr;
use cudarc::driver::DevicePtr;
use cudarc::driver::DevicePtrMut;
use cudarc::driver::LaunchAsync;
use cudarc::driver::LaunchConfig;
use cudarc::nvrtc::CompileError;
use cudarc::nvrtc::CompileOptions;
use cudarc::nvrtc::compile_ptx_with_opts;

const SOURCE: &'static str = include_str!("cuda.cu");

// The names of the kernels that are defined in cuda.cu and loaded in
// new_with_ordinal_and_flags.
const KERNELS: &'static [&'static str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "swish_a",
    "swish_at",
    "softmax_a",
    "softmax_at",
    "sqrt_a",
    "sqrt_at",
    "repeat_col_a",
    "repeat_row_a",
    "abs_a",
    "abs_at",
    "pow_a_b",
    "pow_at_b",
    "pow_a_bt",
    "pow_at_bt",
    "pow_a_b_for_scalar",
    "pow_at_b_for_scalar",
    "rpow_a_b_for_scalar",
    "rpow_at_b_for_scalar",
    "exp_a",
    "exp_at",
    "ln_a",
    "ln_at",
    "log2_a",
    "log2_at",
    "log10_a",
    "log10_at",
    "sin_a",
    "sin_at",
    "cos_a",
    "cos_at",
    "tan_a",
    "tan_at",
    "asin_a",
    "asin_at",
    "acos_a",
    "acos_at",
    "atan_a",
    "atan_at",
    "atan2_a_b",
    "atan2_at_b",
    "atan2_a_bt",
    "atan2_at_bt",
    "atan2_a_b_for_scalar",
    "atan2_at_b_for_scalar",
    "ratan2_a_b_for_scalar",
    "ratan2_at_b_for_scalar",
    "sinh_a",
    "sinh_at",
    "cosh_a",
    "cosh_at",
    "asinh_a",
    "asinh_at",
    "acosh_a",
    "acosh_at",
    "atanh_a",
    "atanh_at",
    "signum_a",
    "signum_at",
    "ceil_a",
    "ceil_at",
    "floor_a",
    "floor_at",
    "round_a",
    "round_at",
    "trunc_a",
    "trunc_at"
];

/// A structure of a CUDA backend array.
///
/// This structure contains a reference to the device memory.
#[derive(Debug)]
pub struct CudaBackendArray
{
    slice: Arc<Mutex<CudaSlice<f32>>>,
    len: usize,
}

struct CudaInnerBackend
{
    device: Arc<CudaDevice>,
    cublas: Option<CudaBlas>,
}

/// A structure of a CUDA backend.
pub struct CudaBackend
{
    inner: Mutex<CudaInnerBackend>,
    has_cublas: bool,
    has_mma: bool,
}

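// A worked example of the tiling arithmetic below (a sketch of the non-mma
// multiplication branch): for n = m = 1000, each thread computes a 4x4 tile
// and a block has 16x16 threads, so the grid gets
// ceil(ceil(1000 / 4) / 16) = 16 blocks per dimension, covering
// 16 * 16 * 4 = 1024 >= 1000 elements. In the mma branch, a block of 1024
// threads covers a 64x64 tile, giving ceil(1000 / 64) = 16 blocks.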
fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
{
    if m == 1 && !is_mul {
        let n2 = ((n + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (n2, 1, 1),
            block_dim: (1024, 1, 1),
            shared_mem_bytes: 0,
        }
    } else if n == 1 && !is_mul {
        let m2 = ((m + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (1, m2, 1),
            block_dim: (1, 1024, 1),
            shared_mem_bytes: 0,
        }
    } else if is_mul {
        if is_mma {
            let n2 = ((n + 63) / 64) as u32;
            let m2 = ((m + 63) / 64) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (1024, 1, 1),
                shared_mem_bytes: 0,
            }
        } else {
            let n2 = (((n + 3) / 4 + 15) / 16) as u32;
            let m2 = (((m + 3) / 4 + 15) / 16) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (16, 16, 1),
                shared_mem_bytes: 0,
            }
        }
    } else {
        let n2 = ((n + 31) / 32) as u32;
        let m2 = ((m + 31) / 32) as u32;
        LaunchConfig {
            grid_dim: (n2, m2, 1),
            block_dim: (32, 32, 1),
            shared_mem_bytes: 0,
        }
    }
}

impl CudaBackend
{
    /// Creates a CUDA backend for the first device.
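    ///
    /// A sketch of handling the result (`no_run`, since it needs a CUDA
    /// device; the module path is assumed from this file's location):
    ///
    /// ```no_run
    /// use unmtx_gpu::cuda::CudaBackend;
    ///
    /// match CudaBackend::new() {
    ///     Ok(backend) => println!("has cuBLAS: {}", backend.has_cublas()),
    ///     Err(_) => eprintln!("can't create a CUDA backend"),
    /// }
    /// ```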
    pub fn new() -> Result<CudaBackend>
    {
        if cfg!(feature = "default_cublas") {
            Self::new_with_ordinal_and_flags(0, true, false)
        } else if cfg!(feature = "default_mma") {
            Self::new_with_ordinal_and_flags(0, false, true)
        } else {
            Self::new_with_ordinal_and_flags(0, false, false)
        }
    }

    /// Creates a CUDA backend for the device with the given ordinal number
    /// and the given flags.
    ///
    /// This method takes the following flags:
    ///
    /// - `is_cublas` - use the cuBLAS library for the multiplication of matrices
    /// - `is_mma` - use the `mma` instruction for the multiplication of matrices
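    ///
    /// For example (a sketch; `no_run` since it needs a second CUDA device
    /// and the cuBLAS library):
    ///
    /// ```no_run
    /// use unmtx_gpu::cuda::CudaBackend;
    ///
    /// // The second device, cuBLAS for multiplication, no mma kernels.
    /// let backend = CudaBackend::new_with_ordinal_and_flags(1, true, false)
    ///     .expect("CUDA backend");
    /// assert!(backend.has_cublas());
    /// ```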
    pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
    {
        let device = match CudaDevice::new(ordinal) {
            Ok(tmp_device) => tmp_device,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let mut options: CompileOptions = Default::default();
        if is_mma {
            options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
            options.arch = Some("sm_80");
        }
        let ptx = match compile_ptx_with_opts(SOURCE, options) {
            Ok(tmp_ptx) => tmp_ptx,
            Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
            Err(err) => return Err(Error::Compilation(format!("{}", err))),
        };
        match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
            Ok(()) => (),
            Err(err) => return Err(Error::Cuda(err)),
        }
        let cublas = if is_cublas {
            match CudaBlas::new(device.clone()) {
                Ok(tmp_cublas) => Some(tmp_cublas),
                Err(err) => return Err(Error::Cublas(err)),
            }
        } else {
            None
        };
        Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
    }

    /// Returns `true` if the backend uses the cuBLAS library for the
    /// multiplication of matrices.
    pub fn has_cublas(&self) -> bool
    { self.has_cublas }

    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                // Lock each distinct slice only once, so that the same mutex is
                // never locked twice when both arrays share one slice.
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                // The match on pointer equality picks the aliasing pattern, so
                // that each distinct slice is locked exactly once.
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                // As in check_and_launch3, the match on pointer equality picks
                // the aliasing pattern, so that each distinct slice is locked
                // exactly once.
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        // a is an n x l matrix, b is an l x m matrix and c is an n x m matrix.
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
                if a2.len != n * l {
                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
                }
                if b2.len != l * m {
                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
                let config = preferred_launch_config(n, m, true, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param(),
                    l.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, c, |a2, c2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, c_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b.as_kernel_param(),
                    c_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n * m {
                    return Err(Error::BackendArrayElemCount(a2.len, n * m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                // The last two parameters pass the tile sizes, which are taken
                // from the block dimensions.
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param(),
                    ((config.block_dim.1) as usize).as_kernel_param(),
                    ((config.block_dim.0) as usize).as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != n {
                    return Err(Error::BackendArrayElemCount(a2.len, n));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
                if a2.len != m {
                    return Err(Error::BackendArrayElemCount(a2.len, m));
                }
                if b2.len != n * m {
                    return Err(Error::BackendArrayElemCount(b2.len, n * m));
                }
                Ok(())
        }, |_, kernel, a_param, b_param| {
                let config = preferred_launch_config(n, m, false, is_mma);
                let mut params = vec![
                    a_param,
                    b_param,
                    n.as_kernel_param(),
                    m.as_kernel_param()
                ];
                unsafe {
                    match kernel.launch(config, &mut params) {
                        Ok(()) => Ok(()),
                        Err(err) => Err(Error::Cuda(err)),
                    }
                }
        })
    }

    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
                if a2.len != n * l {
                    return Err(Error::BackendArrayElemCount(a2.len, n * l));
                }
                if b2.len != l * m {
                    return Err(Error::BackendArrayElemCount(b2.len, l * m));
                }
                if c2.len != n * m {
                    return Err(Error::BackendArrayElemCount(c2.len, n * m));
                }
                Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
                unsafe {
                    match &inner.cublas {
                        Some(cublas) => {
                            let (transa, lda) = if is_trans_a {
                                (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                            };
                            let (transb, ldb) = if is_trans_b {
                                (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                            } else {
                                (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                            };
                            let alpha = 1.0f32;
                            let beta = 0.0f32;
                            // cuBLAS assumes the column-major order, so compute
                            // c^T = b^T * a^T by swapping the operands: a
                            // row-major matrix read as column-major is its own
                            // transpose.
                            let res = sgemm(*cublas.handle(),
                                transb, transa,
                                m as c_int, n as c_int, l as c_int,
                                (&alpha) as *const _,
                                b_device_ptr as *const _, ldb,
                                a_device_ptr as *const _, lda,
                                (&beta) as *const _,
                                c_device_ptr as *mut _, m as c_int);
                            match res {
                                Ok(()) => Ok(()),
                                Err(err) => Err(Error::Cublas(err)),
                            }
                        },
                        None => Err(Error::NoCublas),
                    }
                }
        })
    }
}

impl Backend for CudaBackend
{
    fn name(&self) -> &'static str
    {
        if self.has_cublas {
            "CUDA(cuBLAS)"
        } else if self.has_mma {
            "CUDA(mma)"
        } else {
            "CUDA"
        }
    }

    fn has_cublas(&self) -> bool
    { self.has_cublas }

    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }
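
    // A round-trip sketch (assuming a `backend: CudaBackend` value and a CUDA
    // device): store host data on the device, then load it back.
    //
    //     let a = backend.alloc_and_store(&[1.0f32, 2.0, 3.0, 4.0])?;
    //     let mut buf = [0.0f32; 4];
    //     backend.load(&a, &mut buf)?;
    //     assert_eq!(buf, [1.0, 2.0, 3.0, 4.0]);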

    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                if Arc::ptr_eq(&a2.slice, &b2.slice) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                let mut b_slice_g = mutex_lock(&b2.slice)?;
                match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }

    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
        } else {
            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
        }
    }
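
    // For example (a sketch): multiplying a 2x3 matrix a by a 3x2 matrix b
    // into a 2x2 matrix c corresponds to backend.mul_a_b(&a, &b, &c, 2, 2, 3),
    // since c is n x m, a is n x l and b is l x m.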

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
        } else {
            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
        }
    }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
        } else {
            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l)
        }
    }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
        } else {
            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
        }
    }

    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }
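
    // Note: judging by the names, the `r`-prefixed scalar kernels presumably
    // reverse the operand order (for example, `rsub` computes `b - element`
    // instead of `element - b`); the exact semantics are defined in cuda.cu.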

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }

    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }

    fn swish_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_a", a, b, n, m) }

    fn swish_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_at", a, b, n, m) }

    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn sqrt_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sqrt_a", a, b, n, m) }

    fn sqrt_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sqrt_at", a, b, n, m) }

    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }

    fn abs_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("abs_a", a, b, n, m) }

    fn abs_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("abs_at", a, b, n, m) }

    fn pow_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_a_b", a, b, c, n, m) }

    fn pow_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_at_b", a, b, c, n, m) }

    fn pow_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_a_bt", a, b, c, n, m) }

    fn pow_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_at_bt", a, b, c, n, m) }

    fn pow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("pow_a_b_for_scalar", a, b, c, n, m) }

    fn pow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("pow_at_b_for_scalar", a, b, c, n, m) }

    fn rpow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rpow_a_b_for_scalar", a, b, c, n, m) }

    fn rpow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rpow_at_b_for_scalar", a, b, c, n, m) }

    fn exp_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("exp_a", a, b, n, m) }

    fn exp_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("exp_at", a, b, n, m) }

    fn ln_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ln_a", a, b, n, m) }

    fn ln_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ln_at", a, b, n, m) }

    fn log2_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log2_a", a, b, n, m) }

    fn log2_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log2_at", a, b, n, m) }

    fn log10_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log10_a", a, b, n, m) }

    fn log10_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log10_at", a, b, n, m) }

    fn sin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sin_a", a, b, n, m) }

    fn sin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sin_at", a, b, n, m) }

    fn cos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cos_a", a, b, n, m) }

    fn cos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cos_at", a, b, n, m) }

    fn tan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tan_a", a, b, n, m) }

    fn tan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tan_at", a, b, n, m) }

    fn asin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asin_a", a, b, n, m) }

    fn asin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asin_at", a, b, n, m) }

    fn acos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acos_a", a, b, n, m) }

    fn acos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acos_at", a, b, n, m) }

    fn atan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atan_a", a, b, n, m) }

    fn atan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atan_at", a, b, n, m) }

    fn atan2_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_a_b", a, b, c, n, m) }

    fn atan2_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_at_b", a, b, c, n, m) }

    fn atan2_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_a_bt", a, b, c, n, m) }

    fn atan2_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("atan2_at_bt", a, b, c, n, m) }

    fn atan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("atan2_a_b_for_scalar", a, b, c, n, m) }

    fn atan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("atan2_at_b_for_scalar", a, b, c, n, m) }

    fn ratan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("ratan2_a_b_for_scalar", a, b, c, n, m) }

    fn ratan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("ratan2_at_b_for_scalar", a, b, c, n, m) }

    fn sinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sinh_a", a, b, n, m) }

    fn sinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sinh_at", a, b, n, m) }

    fn cosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cosh_a", a, b, n, m) }

    fn cosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("cosh_at", a, b, n, m) }

    fn asinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asinh_a", a, b, n, m) }

    fn asinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("asinh_at", a, b, n, m) }

    fn acosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acosh_a", a, b, n, m) }

    fn acosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("acosh_at", a, b, n, m) }

    fn atanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atanh_a", a, b, n, m) }

    fn atanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("atanh_at", a, b, n, m) }

    fn signum_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("signum_a", a, b, n, m) }

    fn signum_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("signum_at", a, b, n, m) }

    fn ceil_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ceil_a", a, b, n, m) }

    fn ceil_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ceil_at", a, b, n, m) }

    fn floor_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("floor_a", a, b, n, m) }

    fn floor_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("floor_at", a, b, n, m) }

    fn round_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("round_a", a, b, n, m) }

    fn round_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("round_at", a, b, n, m) }

    fn trunc_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("trunc_a", a, b, n, m) }

    fn trunc_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("trunc_at", a, b, n, m) }
}

#[cfg(test)]
mod tests;