use std::default::Default;
use std::ffi::c_int;
use std::ffi::c_void;
use std::sync::Arc;
use std::sync::Mutex;
use crate::Backend;
use crate::BackendArray;
use crate::Error;
use crate::Result;
use crate::mutex_lock;

pub use cudarc::cublas::result::CublasError;
pub use cudarc::driver::DriverError;

use cudarc::cublas::result::sgemm;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::CudaBlas;
use cudarc::driver::sys::CUdeviceptr;
use cudarc::driver::CudaDevice;
use cudarc::driver::CudaFunction;
use cudarc::driver::CudaSlice;
use cudarc::driver::DeviceRepr;
use cudarc::driver::DevicePtr;
use cudarc::driver::DevicePtrMut;
use cudarc::driver::LaunchAsync;
use cudarc::driver::LaunchConfig;
use cudarc::nvrtc::CompileError;
use cudarc::nvrtc::CompileOptions;
use cudarc::nvrtc::compile_ptx_with_opts;

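// CUDA C source for the kernels and the list of kernel names loaded from it.
// Each name in `KERNELS` must match a global function defined in `cuda.cu`;
// the compiled PTX is registered under the module name "unmtx_gpu".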
const SOURCE: &'static str = include_str!("cuda.cu");

const KERNELS: &'static [&'static str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "swish_a",
    "swish_at",
    "softmax_a",
    "softmax_at",
    "repeat_col_a",
    "repeat_row_a"
];

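/// A backend array that owns a device buffer of `f32` elements.
///
/// The underlying `CudaSlice` is wrapped in `Arc<Mutex<..>>` so that the same
/// buffer can be shared between matrices and can appear as both an input and
/// an output of a single operation.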
#[derive(Debug)]
pub struct CudaBackendArray
{
    slice: Arc<Mutex<CudaSlice<f32>>>,
    len: usize,
}

struct CudaInnerBackend
{
    device: Arc<CudaDevice>,
    cublas: Option<CudaBlas>,
}

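/// A CUDA backend.
///
/// The backend compiles the kernels in `cuda.cu` with NVRTC at construction
/// time and launches them on a single CUDA device. Matrix multiplication can
/// optionally be routed through cuBLAS (`has_cublas`) or through an MMA
/// (tensor core) variant of the kernels (`has_mma`).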
pub struct CudaBackend
{
    inner: Mutex<CudaInnerBackend>,
    has_cublas: bool,
    has_mma: bool,
}

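// Chooses a launch configuration for an n x m result:
// - column or row vectors (and not a multiplication) use 1-D blocks of 1024 threads,
// - matrix multiplication uses 16x16 blocks that each compute a 4x4 tile per thread,
//   or 64x64 tiles with 1024-thread blocks when the MMA kernels are enabled,
// - all other element-wise operations use one thread per element in 32x32 blocks.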
fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
{
    if m == 1 && !is_mul {
        let n2 = ((n + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (n2, 1, 1),
            block_dim: (1024, 1, 1),
            shared_mem_bytes: 0,
        }
    } else if n == 1 && !is_mul {
        let m2 = ((m + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (1, m2, 1),
            block_dim: (1, 1024, 1),
            shared_mem_bytes: 0,
        }
    } else if is_mul {
        if is_mma {
            let n2 = ((n + 63) / 64) as u32;
            let m2 = ((m + 63) / 64) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (1024, 1, 1),
                shared_mem_bytes: 0,
            }
        } else {
            let n2 = (((n + 3) / 4 + 15) / 16) as u32;
            let m2 = (((m + 3) / 4 + 15) / 16) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (16, 16, 1),
                shared_mem_bytes: 0,
            }
        }
    } else {
        let n2 = ((n + 31) / 32) as u32;
        let m2 = ((m + 31) / 32) as u32;
        LaunchConfig {
            grid_dim: (n2, m2, 1),
            block_dim: (32, 32, 1),
            shared_mem_bytes: 0,
        }
    }
}

impl CudaBackend
{
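    /// Creates a new CUDA backend for device 0.
    ///
    /// The `default_cublas` and `default_mma` crate features select whether the
    /// backend is created with the cuBLAS path or the MMA path enabled.
    ///
    /// A minimal usage sketch (the `unmtx_gpu::cuda` module path is an
    /// assumption, hence the ignored doctest):
    ///
    /// ```ignore
    /// use unmtx_gpu::cuda::CudaBackend;
    ///
    /// let backend = CudaBackend::new().unwrap();
    /// let _is_cublas = backend.has_cublas();
    /// ```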
    pub fn new() -> Result<CudaBackend>
    {
        if cfg!(feature = "default_cublas") {
            Self::new_with_ordinal_and_flags(0, true, false)
        } else if cfg!(feature = "default_mma") {
            Self::new_with_ordinal_and_flags(0, false, true)
        } else {
            Self::new_with_ordinal_and_flags(0, false, false)
        }
    }

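    /// Creates a new CUDA backend for the device with the given `ordinal`.
    ///
    /// If `is_cublas` is `true`, matrix multiplication is performed by cuBLAS
    /// `sgemm` instead of the crate's own kernels. If `is_mma` is `true`, the
    /// kernels are compiled with `-DUNMTX_GPU_MMA=1` for the `sm_80`
    /// architecture so that the MMA (tensor core) code paths are used.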
    pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
    {
        let device = match CudaDevice::new(ordinal) {
            Ok(tmp_device) => tmp_device,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let mut options: CompileOptions = Default::default();
        if is_mma {
            options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
            options.arch = Some("sm_80");
        }
        let ptx = match compile_ptx_with_opts(SOURCE, options) {
            Ok(tmp_ptx) => tmp_ptx,
            Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
            Err(err) => return Err(Error::Compilation(format!("{}", err))),
        };
        match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
            Ok(()) => (),
            Err(err) => return Err(Error::Cuda(err)),
        }
        let cublas = if is_cublas {
            match CudaBlas::new(device.clone()) {
                Ok(tmp_cublas) => Some(tmp_cublas),
                Err(err) => return Err(Error::Cublas(err)),
            }
        } else {
            None
        };
        Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
    }

    pub fn has_cublas(&self) -> bool
    { self.has_cublas }

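    // Validates and launches a kernel with two array arguments.
    //
    // `f` checks the element counts of both arrays, `g` performs the actual
    // launch. Because input and output may share the same buffer, the slices
    // are compared with `Arc::ptr_eq` and a shared buffer is locked only once,
    // which avoids deadlocking on the same `Mutex`.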
    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

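    // Validates and launches a kernel with three array arguments.
    //
    // This is the three-operand variant of `check_and_launch2`: every pair of
    // arrays that shares a buffer is locked only once and the same kernel
    // parameter is passed for both positions.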
    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

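    // Validates and launches a cuBLAS call with three array arguments.
    //
    // Unlike the kernel launchers above, `g` receives raw `CUdeviceptr`
    // handles, which are what the cuBLAS bindings expect. Aliased buffers are
    // handled in the same way as in `check_and_launch3`.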
    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                c_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

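    // Validates and launches a matrix multiplication kernel. The dimension
    // convention is: `a` holds n*l elements, `b` holds l*m elements and the
    // result `c` holds n*m elements.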
    fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
            if a2.len != n * l {
                return Err(Error::BackendArrayElemCount(a2.len, n * l));
            }
            if b2.len != l * m {
                return Err(Error::BackendArrayElemCount(b2.len, l * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
            let config = preferred_launch_config(n, m, true, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                c_param,
                n.as_kernel_param(),
                m.as_kernel_param(),
                l.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, c, |a2, c2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, c_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b.as_kernel_param(),
                c_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

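    // Like `check_and_launch_for_fun`, but also passes the block tile
    // dimensions to the kernel; these are presumably used by the softmax
    // kernels for their per-row and per-column reductions.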
    fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param(),
                ((config.block_dim.1) as usize).as_kernel_param(),
                ((config.block_dim.0) as usize).as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != n {
                return Err(Error::BackendArrayElemCount(a2.len, n));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != m {
                return Err(Error::BackendArrayElemCount(a2.len, m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

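    // Validates and performs a matrix multiplication with cuBLAS sgemm.
    //
    // The matrices are stored in row-major order while cuBLAS expects
    // column-major data, so the call computes C^T = op(B)^T * op(A)^T by
    // swapping the operands and the m/n dimensions; the leading dimensions
    // are chosen from the row-major strides of `a` and `b` accordingly.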
    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
            if a2.len != n * l {
                return Err(Error::BackendArrayElemCount(a2.len, n * l));
            }
            if b2.len != l * m {
                return Err(Error::BackendArrayElemCount(b2.len, l * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
            unsafe {
                match &inner.cublas {
                    Some(cublas) => {
                        let (transa, lda) = if is_trans_a {
                            (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                        };
                        let (transb, ldb) = if is_trans_b {
                            (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                        };
                        let alpha = 1.0f32;
                        let beta = 0.0f32;
                        let res = sgemm(*cublas.handle(),
                                transb, transa,
                                m as c_int, n as c_int, l as c_int,
                                (&alpha) as *const _,
                                b_device_ptr as *const _, ldb,
                                a_device_ptr as *const _, lda,
                                (&beta) as *const _,
                                c_device_ptr as *mut _, m as c_int);
                        match res {
                            Ok(()) => Ok(()),
                            Err(err) => Err(Error::Cublas(err)),
                        }
                    },
                    None => Err(Error::NoCublas),
                }
            }
        })
    }
}

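// A minimal end-to-end sketch of how the backend is used through the `Backend`
// trait (a sketch, assuming a CUDA-capable device is available):
//
//     let backend = CudaBackend::new()?;
//     let a = backend.alloc_and_store(&[1.0, 2.0, 3.0, 4.0])?; // a: 2x2
//     let b = backend.alloc_and_store(&[5.0, 6.0, 7.0, 8.0])?; // b: 2x2
//     let c = backend.alloc_and_store_zeros(4)?;               // c: 2x2
//     backend.mul_a_b(&a, &b, &c, 2, 2, 2)?;                   // c = a * b
//     let mut elems = vec![0.0f32; 4];
//     backend.load(&c, &mut elems)?;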
impl Backend for CudaBackend
{
    fn name(&self) -> &'static str
    {
        if self.has_cublas {
            "CUDA(cuBLAS)"
        } else if self.has_mma {
            "CUDA(mma)"
        } else {
            "CUDA"
        }
    }

    fn has_cublas(&self) -> bool
    { self.has_cublas }

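    // Allocates device memory without initializing it, which is why this
    // method is unsafe; the caller must store elements before reading them.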
    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                if Arc::ptr_eq(&a2.slice, &b2.slice) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                let mut b_slice_g = mutex_lock(&b2.slice)?;
                match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

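    // The remaining trait methods delegate to the launch helpers above; the
    // mul_a_b/mul_at_b/mul_a_bt/mul_at_bt methods use cuBLAS when the backend
    // was created with cuBLAS support.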
    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }

    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
        } else {
            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
        }
    }

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
        } else {
            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
        }
    }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
        } else {
            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l)
        }
    }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
        } else {
            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
        }
    }

    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }

    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }

    fn swish_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_a", a, b, n, m) }

    fn swish_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("swish_at", a, b, n, m) }

    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }
}

#[cfg(test)]
mod tests;