1use std::default::Default;
10use std::ffi::c_int;
11use std::ffi::c_void;
12use std::sync::Arc;
13use std::sync::Mutex;
14use crate::Backend;
15use crate::BackendArray;
16use crate::Error;
17use crate::Result;
18use crate::mutex_lock;
19
20pub use cudarc::cublas::result::CublasError;
21pub use cudarc::driver::DriverError;
22
23use cudarc::cublas::result::sgemm;
24use cudarc::cublas::sys::cublasOperation_t;
25use cudarc::cublas::CudaBlas;
26use cudarc::driver::sys::CUdeviceptr;
27use cudarc::driver::CudaDevice;
28use cudarc::driver::CudaFunction;
29use cudarc::driver::CudaSlice;
30use cudarc::driver::DeviceRepr;
31use cudarc::driver::DevicePtr;
32use cudarc::driver::DevicePtrMut;
33use cudarc::driver::LaunchAsync;
34use cudarc::driver::LaunchConfig;
35use cudarc::nvrtc::CompileError;
36use cudarc::nvrtc::CompileOptions;
37use cudarc::nvrtc::compile_ptx_with_opts;
38
39const SOURCE: &'static str = include_str!("cuda.cu");
40
// Names of every kernel defined in `cuda.cu`; all of them are loaded into
// the "unmtx_gpu" module and later looked up by name at launch time.
// (`'static` is implied on consts, so the explicit lifetimes were redundant.)
const KERNELS: &[&str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "swish_a",
    "swish_at",
    "softmax_a",
    "softmax_at",
    "sqrt_a",
    "sqrt_at",
    "repeat_col_a",
    "repeat_row_a",
    "abs_a",
    "abs_at",
    "pow_a_b",
    "pow_at_b",
    "pow_a_bt",
    "pow_at_bt",
    "pow_a_b_for_scalar",
    "pow_at_b_for_scalar",
    "rpow_a_b_for_scalar",
    "rpow_at_b_for_scalar",
    "exp_a",
    "exp_at",
    "ln_a",
    "ln_at",
    "log2_a",
    "log2_at",
    "log10_a",
    "log10_at",
    "sin_a",
    "sin_at",
    "cos_a",
    "cos_at",
    "tan_a",
    "tan_at",
    "asin_a",
    "asin_at",
    "acos_a",
    "acos_at",
    "atan_a",
    "atan_at",
    "atan2_a_b",
    "atan2_at_b",
    "atan2_a_bt",
    "atan2_at_bt",
    "atan2_a_b_for_scalar",
    "atan2_at_b_for_scalar",
    "ratan2_a_b_for_scalar",
    "ratan2_at_b_for_scalar",
    "sinh_a",
    "sinh_at",
    "cosh_a",
    "cosh_at",
    "asinh_a",
    "asinh_at",
    "acosh_a",
    "acosh_at",
    "atanh_a",
    "atanh_at",
    "signum_a",
    "signum_at",
    "ceil_a",
    "ceil_at",
    "floor_a",
    "floor_at",
    "round_a",
    "round_at",
    "trunc_a",
    "trunc_at"
];
146
/// A device-side array used by the CUDA backend.
///
/// The underlying buffer is shared behind `Arc<Mutex<..>>`; `Arc::ptr_eq`
/// on `slice` is used by the launch helpers to detect when two arguments
/// alias the same buffer so the mutex is never locked twice.
#[derive(Debug)]
pub struct CudaBackendArray
{
    // Lock-protected CUDA buffer of f32 elements.
    slice: Arc<Mutex<CudaSlice<f32>>>,
    // Number of f32 elements in the buffer.
    len: usize,
}
156
// Mutex-protected state of the CUDA backend: the device handle plus an
// optional cuBLAS handle (present only when the backend was created with
// cuBLAS enabled).
struct CudaInnerBackend
{
    device: Arc<CudaDevice>,
    cublas: Option<CudaBlas>,
}
162
/// A matrix-operation backend that runs on a CUDA device.
pub struct CudaBackend
{
    // Device state; every operation locks this mutex before launching.
    inner: Mutex<CudaInnerBackend>,
    // When true, matrix products are delegated to cuBLAS sgemm.
    has_cublas: bool,
    // When true, kernels were compiled with -DUNMTX_GPU_MMA=1 for sm_80.
    has_mma: bool,
}
170
171fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
172{
173 if m == 1 && !is_mul {
174 let n2 = ((n + 1023) / 1024) as u32;
175 LaunchConfig {
176 grid_dim: (n2, 1, 1),
177 block_dim: (1024, 1, 1),
178 shared_mem_bytes: 0,
179 }
180 } else if n == 1 && !is_mul {
181 let m2 = ((m + 1023) / 1024) as u32;
182 LaunchConfig {
183 grid_dim: (1, m2, 1),
184 block_dim: (1, 1024, 1),
185 shared_mem_bytes: 0,
186 }
187 } else if is_mul {
188 if is_mma {
189 let n2 = ((n + 63) / 64) as u32;
190 let m2 = ((m + 63) / 64) as u32;
191 LaunchConfig {
192 grid_dim: (n2, m2, 1),
193 block_dim: (1024, 1, 1),
194 shared_mem_bytes: 0,
195 }
196 } else {
197 let n2 = (((n + 3) / 4 + 15) / 16) as u32;
198 let m2 = (((m + 3) / 4 + 15) / 16) as u32;
199 LaunchConfig {
200 grid_dim: (n2, m2, 1),
201 block_dim: (16, 16, 1),
202 shared_mem_bytes: 0,
203 }
204 }
205 } else {
206 let n2 = ((n + 31) / 32) as u32;
207 let m2 = ((m + 31) / 32) as u32;
208 LaunchConfig {
209 grid_dim: (n2, m2, 1),
210 block_dim: (32, 32, 1),
211 shared_mem_bytes: 0,
212 }
213 }
214}
215
216impl CudaBackend
217{
218 pub fn new() -> Result<CudaBackend>
220 {
221 if cfg!(feature = "default_cublas") {
222 Self::new_with_ordinal_and_flags(0, true, false)
223 } else if cfg!(feature = "default_mma") {
224 Self::new_with_ordinal_and_flags(0, false, true)
225 } else {
226 Self::new_with_ordinal_and_flags(0, false, false)
227 }
228 }
229
230 pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
237 {
238 let device = match CudaDevice::new(ordinal) {
239 Ok(tmp_device) => tmp_device,
240 Err(err) => return Err(Error::Cuda(err)),
241 };
242 let mut options: CompileOptions = Default::default();
243 if is_mma {
244 options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
245 options.arch = Some("sm_80");
246 }
247 let ptx = match compile_ptx_with_opts(SOURCE, options) {
248 Ok(tmp_ptx) => tmp_ptx,
249 Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
250 Err(err) => return Err(Error::Compilation(format!("{}", err))),
251 };
252 match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
253 Ok(()) => (),
254 Err(err) => return Err(Error::Cuda(err)),
255 }
256 let cublas = if is_cublas {
257 match CudaBlas::new(device.clone()) {
258 Ok(tmp_cublas) => Some(tmp_cublas),
259 Err(err) => return Err(Error::Cublas(err)),
260 }
261 } else {
262 None
263 };
264 Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
265 }
266
    /// Returns `true` when this backend delegates matrix products to cuBLAS.
    pub fn has_cublas(&self) -> bool
    { self.has_cublas }
269
    // Validates two backend arrays with `f`, then resolves `kernel_name` in
    // the "unmtx_gpu" module and invokes `g` with raw kernel parameters for
    // `a` and `b`. Aliasing (both arguments backed by the same buffer) is
    // detected via `Arc::ptr_eq` so the same mutex is never locked twice.
    // Synchronizes the device before returning.
    // NOTE(review): buffer locks are taken in argument order while `inner`
    // is held; concurrent calls with swapped buffer order could in principle
    // deadlock - confirm whether callers can race here.
    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                // Run the caller's precondition checks (element counts etc.).
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    // Same buffer: lock once and pass it for both parameters.
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                // Wait for the launched kernel to finish before releasing locks.
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
300
    // Three-array variant of `check_and_launch2`: validates `a`, `b`, `c`
    // with `f`, then invokes `g` with a kernel parameter per array. The match
    // below enumerates which arguments share a buffer (by `Arc` identity) so
    // each distinct mutex is locked exactly once. Synchronizes the device
    // before returning.
    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                // Tuple is (a==b, a==c, b==c); `_` covers every case where
                // all three share one buffer.
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    // All distinct: lock all three.
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    // a and b share a buffer.
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    // a and c share a buffer.
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    // b and c share a buffer.
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    // All three share one buffer: lock it once.
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
350
    // cuBLAS analogue of `check_and_launch3`: instead of kernel parameters,
    // `g` receives raw device pointers for `a`, `b` and `c`. The aliasing
    // match mirrors `check_and_launch3` so each distinct buffer mutex is
    // locked exactly once. Synchronizes the device before returning.
    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
            G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                // Tuple is (a==b, a==c, b==c) by Arc identity.
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    // a and b share a buffer.
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    // a and c share a buffer.
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    // b and c share a buffer.
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    // All three share one buffer.
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }
406
407 fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
408 {
409 let is_mma = self.has_mma;
410 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
411 if a2.len != n * m {
412 return Err(Error::BackendArrayElemCount(a2.len, n * m));
413 }
414 if b2.len != n * m {
415 return Err(Error::BackendArrayElemCount(b2.len, n * m));
416 }
417 Ok(())
418 }, |_, kernel, a_param, b_param| {
419 let config = preferred_launch_config(n, m, false, is_mma);
420 let mut params = vec![
421 a_param,
422 b_param,
423 n.as_kernel_param(),
424 m.as_kernel_param()
425 ];
426 unsafe {
427 match kernel.launch(config, &mut params) {
428 Ok(()) => Ok(()),
429 Err(err) => Err(Error::Cuda(err)),
430 }
431 }
432 })
433 }
434
435 fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
436 {
437 let is_mma = self.has_mma;
438 self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
439 if a2.len != n * m {
440 return Err(Error::BackendArrayElemCount(a2.len, n * m));
441 }
442 if b2.len != n * m {
443 return Err(Error::BackendArrayElemCount(b2.len, n * m));
444 }
445 if c2.len != n * m {
446 return Err(Error::BackendArrayElemCount(c2.len, n * m));
447 }
448 Ok(())
449 }, |_, kernel, a_param, b_param, c_param| {
450 let config = preferred_launch_config(n, m, false, is_mma);
451 let mut params = vec![
452 a_param,
453 b_param,
454 c_param,
455 n.as_kernel_param(),
456 m.as_kernel_param()
457 ];
458 unsafe {
459 match kernel.launch(config, &mut params) {
460 Ok(()) => Ok(()),
461 Err(err) => Err(Error::Cuda(err)),
462 }
463 }
464 })
465 }
466
467 fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
468 {
469 let is_mma = self.has_mma;
470 self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
471 if a2.len != n * l {
472 return Err(Error::BackendArrayElemCount(a2.len, n * l));
473 }
474 if b2.len != l * m {
475 return Err(Error::BackendArrayElemCount(b2.len, l * m));
476 }
477 if c2.len != n * m {
478 return Err(Error::BackendArrayElemCount(c2.len, n * m));
479 }
480 Ok(())
481 }, |_, kernel, a_param, b_param, c_param| {
482 let config = preferred_launch_config(n, m, true, is_mma);
483 let mut params = vec![
484 a_param,
485 b_param,
486 c_param,
487 n.as_kernel_param(),
488 m.as_kernel_param(),
489 l.as_kernel_param()
490 ];
491 unsafe {
492 match kernel.launch(config, &mut params) {
493 Ok(()) => Ok(()),
494 Err(err) => Err(Error::Cuda(err)),
495 }
496 }
497 })
498 }
499
500 fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
501 {
502 let is_mma = self.has_mma;
503 self.check_and_launch2(kernel_name, a, c, |a2, c2| {
504 if a2.len != n * m {
505 return Err(Error::BackendArrayElemCount(a2.len, n * m));
506 }
507 if c2.len != n * m {
508 return Err(Error::BackendArrayElemCount(c2.len, n * m));
509 }
510 Ok(())
511 }, |_, kernel, a_param, c_param| {
512 let config = preferred_launch_config(n, m, false, is_mma);
513 let mut params = vec![
514 a_param,
515 b.as_kernel_param(),
516 c_param,
517 n.as_kernel_param(),
518 m.as_kernel_param()
519 ];
520 unsafe {
521 match kernel.launch(config, &mut params) {
522 Ok(()) => Ok(()),
523 Err(err) => Err(Error::Cuda(err)),
524 }
525 }
526 })
527 }
528
529 fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
530 {
531 let is_mma = self.has_mma;
532 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
533 if a2.len != n * m {
534 return Err(Error::BackendArrayElemCount(a2.len, n * m));
535 }
536 if b2.len != n * m {
537 return Err(Error::BackendArrayElemCount(b2.len, n * m));
538 }
539 Ok(())
540 }, |_, kernel, a_param, b_param| {
541 let config = preferred_launch_config(n, m, false, is_mma);
542 let mut params = vec![
543 a_param,
544 b_param,
545 n.as_kernel_param(),
546 m.as_kernel_param(),
547 ((config.block_dim.1) as usize).as_kernel_param(),
548 ((config.block_dim.0) as usize).as_kernel_param()
549 ];
550 unsafe {
551 match kernel.launch(config, &mut params) {
552 Ok(()) => Ok(()),
553 Err(err) => Err(Error::Cuda(err)),
554 }
555 }
556 })
557 }
558
559 fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
560 {
561 let is_mma = self.has_mma;
562 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
563 if a2.len != n {
564 return Err(Error::BackendArrayElemCount(a2.len, n));
565 }
566 if b2.len != n * m {
567 return Err(Error::BackendArrayElemCount(b2.len, n * m));
568 }
569 Ok(())
570 }, |_, kernel, a_param, b_param| {
571 let config = preferred_launch_config(n, m, false, is_mma);
572 let mut params = vec![
573 a_param,
574 b_param,
575 n.as_kernel_param(),
576 m.as_kernel_param()
577 ];
578 unsafe {
579 match kernel.launch(config, &mut params) {
580 Ok(()) => Ok(()),
581 Err(err) => Err(Error::Cuda(err)),
582 }
583 }
584 })
585 }
586
587 fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
588 {
589 let is_mma = self.has_mma;
590 self.check_and_launch2(kernel_name, a, b, |a2, b2| {
591 if a2.len != m {
592 return Err(Error::BackendArrayElemCount(a2.len, m));
593 }
594 if b2.len != n * m {
595 return Err(Error::BackendArrayElemCount(b2.len, n * m));
596 }
597 Ok(())
598 }, |_, kernel, a_param, b_param| {
599 let config = preferred_launch_config(n, m, false, is_mma);
600 let mut params = vec![
601 a_param,
602 b_param,
603 n.as_kernel_param(),
604 m.as_kernel_param()
605 ];
606 unsafe {
607 match kernel.launch(config, &mut params) {
608 Ok(()) => Ok(()),
609 Err(err) => Err(Error::Cuda(err)),
610 }
611 }
612 })
613 }
614
    // Matrix product via cuBLAS sgemm: c = op(a) * op(b), with `a` holding
    // n * l elements, `b` l * m and `c` n * m (element counts are validated
    // for those totals regardless of the transposition flags).
    //
    // cuBLAS is column-major while these buffers are used row-major, so the
    // call below swaps the operands (B before A) and the m/n dimensions -
    // computing C^T = op(B)^T * op(A)^T in column-major terms, which leaves
    // the row-major C in place. NOTE(review): verify the lda/ldb choices for
    // the transposed variants against the cuBLAS sgemm reference.
    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
            if a2.len != n * l {
                return Err(Error::BackendArrayElemCount(a2.len, n * l));
            }
            if b2.len != l * m {
                return Err(Error::BackendArrayElemCount(b2.len, l * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
            unsafe {
                match &inner.cublas {
                    Some(cublas) => {
                        // Operation flag and leading dimension per operand.
                        let (transa, lda) = if is_trans_a {
                            (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                        };
                        let (transb, ldb) = if is_trans_b {
                            (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                        };
                        // Plain product: no scaling, no accumulation.
                        let alpha = 1.0f32;
                        let beta = 0.0f32;
                        let res = sgemm(*cublas.handle(),
                            transb, transa,
                            m as c_int, n as c_int, l as c_int,
                            (&alpha) as *const _,
                            b_device_ptr as *const _, ldb,
                            a_device_ptr as *const _, lda,
                            (&beta) as *const _,
                            c_device_ptr as *mut _, m as c_int);
                        match res {
                            Ok(()) => Ok(()),
                            Err(err) => Err(Error::Cublas(err)),
                        }
                    },
                    None => Err(Error::NoCublas),
                }
            }
        })
    }
662}
663
664impl Backend for CudaBackend
665{
666 fn name(&self) -> &'static str
667 {
668 if self.has_cublas {
669 "CUDA(cuBLAS)"
670 } else if self.has_mma {
671 "CUDA(mma)"
672 } else {
673 "CUDA"
674 }
675 }
676
    /// Returns `true` when this backend delegates matrix products to cuBLAS.
    fn has_cublas(&self) -> bool
    { self.has_cublas }
679
680 unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
681 {
682 let inner_g = mutex_lock(&self.inner)?;
683 let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
684 Ok(tmp_slice) => tmp_slice,
685 Err(err) => return Err(Error::Cuda(err)),
686 };
687 let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
688 Ok(BackendArray::Cuda(cuda_array))
689 }
690
691 fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
692 {
693 let inner_g = mutex_lock(&self.inner)?;
694 let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
695 Ok(tmp_slice) => tmp_slice,
696 Err(err) => return Err(Error::Cuda(err)),
697 };
698 let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
699 Ok(BackendArray::Cuda(cuda_array))
700 }
701
702 fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
703 {
704 let inner_g = mutex_lock(&self.inner)?;
705 let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
706 Ok(tmp_slice) => tmp_slice,
707 Err(err) => return Err(Error::Cuda(err)),
708 };
709 let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
710 Ok(BackendArray::Cuda(cuda_array))
711 }
712
713 fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
714 {
715 #[allow(unreachable_patterns)]
716 match a {
717 BackendArray::Cuda(a2) => {
718 if a2.len != elems.len() {
719 return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
720 }
721 let inner_g = mutex_lock(&self.inner)?;
722 let a_slice_g = mutex_lock(&a2.slice)?;
723 match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
724 Ok(()) => (),
725 Err(err) => return Err(Error::Cuda(err)),
726 }
727 },
728 _ => return Err(Error::InvalidBackendArray),
729 }
730 Ok(())
731 }
732
733 fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
734 {
735 #[allow(unreachable_patterns)]
736 match a {
737 BackendArray::Cuda(a2) => {
738 if a2.len != elems.len() {
739 return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
740 }
741 let inner_g = mutex_lock(&self.inner)?;
742 let mut a_slice_g = mutex_lock(&a2.slice)?;
743 match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
744 Ok(()) => (),
745 Err(err) => return Err(Error::Cuda(err)),
746 }
747 },
748 _ => return Err(Error::InvalidBackendArray),
749 }
750 Ok(())
751 }
752
753 fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
754 {
755 #[allow(unreachable_patterns)]
756 match (a, b) {
757 (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
758 if Arc::ptr_eq(&a2.slice, &b2.slice) {
759 return Ok(());
760 }
761 if a2.len != b2.len {
762 return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
763 }
764 let inner_g = mutex_lock(&self.inner)?;
765 let a_slice_g = mutex_lock(&a2.slice)?;
766 let mut b_slice_g = mutex_lock(&b2.slice)?;
767 match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
768 Ok(()) => (),
769 Err(err) => return Err(Error::Cuda(err)),
770 }
771 match inner_g.device.synchronize() {
772 Ok(()) => (),
773 Err(err) => return Err(Error::Cuda(err)),
774 }
775 },
776 _ => return Err(Error::InvalidBackendArray),
777 }
778 Ok(())
779 }
780
    // The methods below are thin dispatchers: each forwards to a generic
    // launcher with the name of the CUDA kernel implementing it. The `at` /
    // `bt` suffixes name kernel variants that treat the corresponding
    // operand as transposed (the actual indexing lives in cuda.cu).

    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    // Element-wise additions for all transposition combinations.
    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    // Element-wise subtractions for all transposition combinations.
    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }

    // Matrix products: delegated to cuBLAS sgemm when the backend was
    // created with cuBLAS, otherwise to the hand-written kernels.
    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
        } else {
            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
        }
    }

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
        } else {
            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
        }
    }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
        } else {
            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l)
        }
    }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
        } else {
            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
        }
    }
843
    // Element-wise (Hadamard) products and divisions; the `_for_elems`
    // kernels operate per element, unlike the matrix products above.
    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    // Matrix-scalar operations; the `r`-prefixed kernels are the reversed
    // forms (scalar on the left-hand side of the operator).
    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }
903
    // Unary element-wise functions; each forwards to the generic unary
    // launcher with the matching kernel name.
    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }

    fn swish_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_a", a, b, n, m) }

    fn swish_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_at", a, b, n, m) }

    // Softmax needs the block tile extents as extra kernel arguments, so it
    // goes through the tiles variant of the launcher.
    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn sqrt_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sqrt_a", a, b, n, m) }

    fn sqrt_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sqrt_at", a, b, n, m) }

    // Vector-to-matrix broadcasts: the input vector length is n (col) or
    // m (row), so these use dedicated launchers with the right length check.
    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }

    fn abs_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("abs_a", a, b, n, m) }

    fn abs_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("abs_at", a, b, n, m) }

    // Element-wise powers, with matrix and scalar exponents (`rpow` takes
    // the scalar as the base instead of the exponent).
    fn pow_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_a_b", a, b, c, n, m) }

    fn pow_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_at_b", a, b, c, n, m) }

    fn pow_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_a_bt", a, b, c, n, m) }

    fn pow_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("pow_at_bt", a, b, c, n, m) }

    fn pow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("pow_a_b_for_scalar", a, b, c, n, m) }

    fn pow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("pow_at_b_for_scalar", a, b, c, n, m) }

    fn rpow_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rpow_a_b_for_scalar", a, b, c, n, m) }

    fn rpow_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rpow_at_b_for_scalar", a, b, c, n, m) }

    // Exponentials and logarithms.
    fn exp_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("exp_a", a, b, n, m) }

    fn exp_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("exp_at", a, b, n, m) }

    fn ln_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ln_a", a, b, n, m) }

    fn ln_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("ln_at", a, b, n, m) }

    fn log2_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("log2_a", a, b, n, m) }
984
985 fn log2_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
986 { self.check_and_launch_for_fun("log2_at", a, b, n, m) }
987
988 fn log10_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
989 { self.check_and_launch_for_fun("log10_a", a, b, n, m) }
990
991 fn log10_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
992 { self.check_and_launch_for_fun("log10_at", a, b, n, m) }
993
994 fn sin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
995 { self.check_and_launch_for_fun("sin_a", a, b, n, m) }
996
997 fn sin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
998 { self.check_and_launch_for_fun("sin_at", a, b, n, m) }
999
1000 fn cos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1001 { self.check_and_launch_for_fun("cos_a", a, b, n, m) }
1002
1003 fn cos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1004 { self.check_and_launch_for_fun("cos_at", a, b, n, m) }
1005
1006 fn tan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1007 { self.check_and_launch_for_fun("tan_a", a, b, n, m) }
1008
1009 fn tan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1010 { self.check_and_launch_for_fun("tan_at", a, b, n, m) }
1011
1012 fn asin_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1013 { self.check_and_launch_for_fun("asin_a", a, b, n, m) }
1014
1015 fn asin_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1016 { self.check_and_launch_for_fun("asin_at", a, b, n, m) }
1017
1018 fn acos_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1019 { self.check_and_launch_for_fun("acos_a", a, b, n, m) }
1020
1021 fn acos_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1022 { self.check_and_launch_for_fun("acos_at", a, b, n, m) }
1023
1024 fn atan_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1025 { self.check_and_launch_for_fun("atan_a", a, b, n, m) }
1026
1027 fn atan_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1028 { self.check_and_launch_for_fun("atan_at", a, b, n, m) }
1029
1030 fn atan2_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
1031 { self.check_and_launch_for_op("atan2_a_b", a, b, c, n, m) }
1032
1033 fn atan2_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
1034 { self.check_and_launch_for_op("atan2_at_b", a, b, c, n, m) }
1035
1036 fn atan2_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
1037 { self.check_and_launch_for_op("atan2_a_bt", a, b, c, n, m) }
1038
1039 fn atan2_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
1040 { self.check_and_launch_for_op("atan2_at_bt", a, b, c, n, m) }
1041
1042 fn atan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
1043 { self.check_and_launch_for_scalar("atan2_a_b_for_scalar", a, b, c, n, m) }
1044
1045 fn atan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
1046 { self.check_and_launch_for_scalar("atan2_at_b_for_scalar", a, b, c, n, m) }
1047
1048 fn ratan2_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
1049 { self.check_and_launch_for_scalar("ratan2_a_b_for_scalar", a, b, c, n, m) }
1050
1051 fn ratan2_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
1052 { self.check_and_launch_for_scalar("ratan2_at_b_for_scalar", a, b, c, n, m) }
1053
1054 fn sinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1055 { self.check_and_launch_for_fun("sinh_a", a, b, n, m) }
1056
1057 fn sinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1058 { self.check_and_launch_for_fun("sinh_at", a, b, n, m) }
1059
1060 fn cosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1061 { self.check_and_launch_for_fun("cosh_a", a, b, n, m) }
1062
1063 fn cosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1064 { self.check_and_launch_for_fun("cosh_at", a, b, n, m) }
1065
1066 fn asinh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1067 { self.check_and_launch_for_fun("asinh_a", a, b, n, m) }
1068
1069 fn asinh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1070 { self.check_and_launch_for_fun("asinh_at", a, b, n, m) }
1071
1072 fn acosh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1073 { self.check_and_launch_for_fun("acosh_a", a, b, n, m) }
1074
1075 fn acosh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1076 { self.check_and_launch_for_fun("acosh_at", a, b, n, m) }
1077
1078 fn atanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1079 { self.check_and_launch_for_fun("atanh_a", a, b, n, m) }
1080
1081 fn atanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1082 { self.check_and_launch_for_fun("atanh_at", a, b, n, m) }
1083
1084 fn signum_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1085 { self.check_and_launch_for_fun("signum_a", a, b, n, m) }
1086
1087 fn signum_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1088 { self.check_and_launch_for_fun("signum_at", a, b, n, m) }
1089
1090 fn ceil_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1091 { self.check_and_launch_for_fun("ceil_a", a, b, n, m) }
1092
1093 fn ceil_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1094 { self.check_and_launch_for_fun("ceil_at", a, b, n, m) }
1095
1096 fn floor_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1097 { self.check_and_launch_for_fun("floor_a", a, b, n, m) }
1098
1099 fn floor_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1100 { self.check_and_launch_for_fun("floor_at", a, b, n, m) }
1101
1102 fn round_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1103 { self.check_and_launch_for_fun("round_a", a, b, n, m) }
1104
1105 fn round_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1106 { self.check_and_launch_for_fun("round_at", a, b, n, m) }
1107
1108 fn trunc_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1109 { self.check_and_launch_for_fun("trunc_a", a, b, n, m) }
1110
1111 fn trunc_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
1112 { self.check_and_launch_for_fun("trunc_at", a, b, n, m) }
1113}
1114
// Unit tests for this module; compiled only in test builds.
#[cfg(test)]
mod tests;