use std::default::Default;
use std::ffi::c_int;
use std::ffi::c_void;
use std::sync::Arc;
use std::sync::Mutex;
use crate::Backend;
use crate::BackendArray;
use crate::Error;
use crate::Result;
use crate::mutex_lock;

pub use cudarc::cublas::result::CublasError;
pub use cudarc::driver::DriverError;

use cudarc::cublas::result::sgemm;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::CudaBlas;
use cudarc::driver::sys::CUdeviceptr;
use cudarc::driver::CudaDevice;
use cudarc::driver::CudaFunction;
use cudarc::driver::CudaSlice;
use cudarc::driver::DeviceRepr;
use cudarc::driver::DevicePtr;
use cudarc::driver::DevicePtrMut;
use cudarc::driver::LaunchAsync;
use cudarc::driver::LaunchConfig;
use cudarc::nvrtc::CompileError;
use cudarc::nvrtc::CompileOptions;
use cudarc::nvrtc::compile_ptx_with_opts;

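// The CUDA C source for all custom kernels is embedded at build time and
// compiled to PTX with NVRTC when the backend is created.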
const SOURCE: &'static str = include_str!("cuda.cu");

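// Names of the __global__ functions loaded from the compiled PTX module;
// they must match the kernel names defined in cuda.cu.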
const KERNELS: &'static [&'static str] = &[
    "transpose_a",
    "add_a_b",
    "add_at_b",
    "add_a_bt",
    "add_at_bt",
    "sub_a_b",
    "sub_at_b",
    "sub_a_bt",
    "sub_at_bt",
    "mul_a_b",
    "mul_at_b",
    "mul_a_bt",
    "mul_at_bt",
    "mul_a_b_for_elems",
    "mul_at_b_for_elems",
    "mul_a_bt_for_elems",
    "mul_at_bt_for_elems",
    "div_a_b_for_elems",
    "div_at_b_for_elems",
    "div_a_bt_for_elems",
    "div_at_bt_for_elems",
    "add_a_b_for_scalar",
    "add_at_b_for_scalar",
    "sub_a_b_for_scalar",
    "sub_at_b_for_scalar",
    "rsub_a_b_for_scalar",
    "rsub_at_b_for_scalar",
    "mul_a_b_for_scalar",
    "mul_at_b_for_scalar",
    "div_a_b_for_scalar",
    "div_at_b_for_scalar",
    "rdiv_a_b_for_scalar",
    "rdiv_at_b_for_scalar",
    "sigmoid_a",
    "sigmoid_at",
    "tanh_a",
    "tanh_at",
    "softmax_a",
    "softmax_at",
    "repeat_col_a",
    "repeat_row_a"
];

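/// A backend array for the CUDA backend.
///
/// The array wraps a reference-counted, mutex-guarded device buffer of `f32`
/// values together with its element count, so clones of a [`BackendArray`]
/// can alias the same device memory.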
#[derive(Debug)]
pub struct CudaBackendArray
{
    slice: Arc<Mutex<CudaSlice<f32>>>,
    len: usize,
}

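// Device handle and optional cuBLAS context, guarded by a single mutex
// inside `CudaBackend`.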
struct CudaInnerBackend
{
    device: Arc<CudaDevice>,
    cublas: Option<CudaBlas>,
}

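/// A backend that executes matrix operations on a CUDA device.
///
/// The backend can optionally use cuBLAS for matrix multiplication and an
/// MMA (tensor core) variant of the custom kernels.
///
/// A minimal usage sketch (the paths and the exact re-exports are
/// illustrative assumptions, not checked here):
///
/// ```ignore
/// let backend = CudaBackend::new()?;
/// // Two 2x2 matrices and a zeroed result buffer.
/// let a = backend.alloc_and_store(&[1.0, 2.0, 3.0, 4.0])?;
/// let b = backend.alloc_and_store(&[5.0, 6.0, 7.0, 8.0])?;
/// let c = backend.alloc_and_store_zeros(4)?;
/// backend.mul_a_b(&a, &b, &c, 2, 2, 2)?;
/// let mut elems = [0.0f32; 4];
/// backend.load(&c, &mut elems)?;
/// ```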
pub struct CudaBackend
{
    inner: Mutex<CudaInnerBackend>,
    has_cublas: bool,
    has_mma: bool,
}

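// Chooses a launch configuration for an n x m result: 1D blocks of 1024
// threads for row or column vectors, 32 x 32 blocks for element-wise
// operations, and larger tiles for matrix multiplication (a 64 x 64 tile
// per 1024-thread block in the MMA variant, 4 x 4 elements per thread in a
// 16 x 16 block otherwise).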
fn preferred_launch_config(n: usize, m: usize, is_mul: bool, is_mma: bool) -> LaunchConfig
{
    if m == 1 && !is_mul {
        let n2 = ((n + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (n2, 1, 1),
            block_dim: (1024, 1, 1),
            shared_mem_bytes: 0,
        }
    } else if n == 1 && !is_mul {
        let m2 = ((m + 1023) / 1024) as u32;
        LaunchConfig {
            grid_dim: (1, m2, 1),
            block_dim: (1, 1024, 1),
            shared_mem_bytes: 0,
        }
    } else if is_mul {
        if is_mma {
            let n2 = ((n + 63) / 64) as u32;
            let m2 = ((m + 63) / 64) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (1024, 1, 1),
                shared_mem_bytes: 0,
            }
        } else {
            let n2 = (((n + 3) / 4 + 15) / 16) as u32;
            let m2 = (((m + 3) / 4 + 15) / 16) as u32;
            LaunchConfig {
                grid_dim: (n2, m2, 1),
                block_dim: (16, 16, 1),
                shared_mem_bytes: 0,
            }
        }
    } else {
        let n2 = ((n + 31) / 32) as u32;
        let m2 = ((m + 31) / 32) as u32;
        LaunchConfig {
            grid_dim: (n2, m2, 1),
            block_dim: (32, 32, 1),
            shared_mem_bytes: 0,
        }
    }
}

impl CudaBackend
{
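    /// Creates a CUDA backend for the first device (ordinal 0).
    ///
    /// The `default_cublas` and `default_mma` crate features select whether
    /// the backend is created with cuBLAS or with the MMA kernels by default.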
    pub fn new() -> Result<CudaBackend>
    {
        if cfg!(feature = "default_cublas") {
            Self::new_with_ordinal_and_flags(0, true, false)
        } else if cfg!(feature = "default_mma") {
            Self::new_with_ordinal_and_flags(0, false, true)
        } else {
            Self::new_with_ordinal_and_flags(0, false, false)
        }
    }

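    /// Creates a CUDA backend for the device with the given ordinal.
    ///
    /// When `is_cublas` is `true`, a cuBLAS context is created and used for
    /// matrix multiplication. When `is_mma` is `true`, the kernels are
    /// compiled with `-DUNMTX_GPU_MMA=1` for the `sm_80` architecture.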
    pub fn new_with_ordinal_and_flags(ordinal: usize, is_cublas: bool, is_mma: bool) -> Result<CudaBackend>
    {
        let device = match CudaDevice::new(ordinal) {
            Ok(tmp_device) => tmp_device,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let mut options: CompileOptions = Default::default();
        if is_mma {
            options.options = vec![String::from("-DUNMTX_GPU_MMA=1")];
            options.arch = Some("sm_80");
        }
        let ptx = match compile_ptx_with_opts(SOURCE, options) {
            Ok(tmp_ptx) => tmp_ptx,
            Err(CompileError::CompileError { log, .. }) => return Err(Error::Compilation(log.as_c_str().to_string_lossy().into_owned())),
            Err(err) => return Err(Error::Compilation(format!("{}", err))),
        };
        match device.load_ptx(ptx, "unmtx_gpu", KERNELS) {
            Ok(()) => (),
            Err(err) => return Err(Error::Cuda(err)),
        }
        let cublas = if is_cublas {
            match CudaBlas::new(device.clone()) {
                Ok(tmp_cublas) => Some(tmp_cublas),
                Err(err) => return Err(Error::Cublas(err)),
            }
        } else {
            None
        };
        Ok(CudaBackend { inner: Mutex::new(CudaInnerBackend { device, cublas, }), has_cublas: is_cublas, has_mma: is_mma, })
    }

    pub fn has_cublas(&self) -> bool
    { self.has_cublas }

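    // Validates two backend arrays with `f`, looks up the kernel, locks the
    // slices (handling the case where both arrays share one slice), launches
    // the kernel through `g`, and synchronizes the device.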
    fn check_and_launch2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                if !Arc::ptr_eq(&a2.slice, &b2.slice) {
                    let a_slice_g = mutex_lock(&a2.slice)?;
                    let mut b_slice_g = mutex_lock(&b2.slice)?;
                    g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?;
                } else {
                    let mut a_slice_g = mutex_lock(&a2.slice)?;
                    g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?;
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

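    // Like `check_and_launch2`, but for three backend arrays; the match on
    // `Arc::ptr_eq` avoids locking the same slice twice when arguments alias.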
    fn check_and_launch3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CudaFunction, *mut c_void, *mut c_void, *mut c_void) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match inner_g.device.get_func("unmtx_gpu", kernel_name) {
                    Some(tmp_kernel) => tmp_kernel,
                    None => return Err(Error::NoKernel(String::from(kernel_name))),
                };
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*a_slice_g)).as_kernel_param(), (&mut (*c_slice_g)).as_kernel_param())?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&(*b_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        g(&*inner_g, kernel, (&(*a_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param(), (&mut (*b_slice_g)).as_kernel_param())?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        g(&*inner_g, kernel, (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param(), (&mut (*a_slice_g)).as_kernel_param())?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

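    // Variant of `check_and_launch3` that hands raw device pointers to `g`
    // instead of kernel parameters, for use with cuBLAS calls.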
    fn check_and_launch_cublas3<F, G>(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&CudaBackendArray, &CudaBackendArray, &CudaBackendArray) -> Result<()>,
              G: FnOnce(&CudaInnerBackend, CUdeviceptr, CUdeviceptr, CUdeviceptr) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2), BackendArray::Cuda(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                match (Arc::ptr_eq(&a2.slice, &b2.slice), Arc::ptr_eq(&a2.slice, &c2.slice), Arc::ptr_eq(&b2.slice, &c2.slice)) {
                    (false, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, c_device_ptr)?
                    },
                    (true, false, false) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut c_slice_g = mutex_lock(&c2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let c_device_ptr = *(&mut (*c_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, c_device_ptr)?
                    },
                    (false, true, false) => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        let b_device_ptr = *(&(*b_slice_g)).device_ptr();
                        g(&*inner_g, a_device_ptr, b_device_ptr, a_device_ptr)?
                    },
                    (false, false, true) => {
                        let a_slice_g = mutex_lock(&a2.slice)?;
                        let mut b_slice_g = mutex_lock(&b2.slice)?;
                        let a_device_ptr = *(&(*a_slice_g)).device_ptr();
                        let b_device_ptr = *(&mut (*b_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, b_device_ptr, b_device_ptr)?
                    },
                    _ => {
                        let mut a_slice_g = mutex_lock(&a2.slice)?;
                        let a_device_ptr = *(&mut (*a_slice_g)).device_ptr_mut();
                        g(&*inner_g, a_device_ptr, a_device_ptr, a_device_ptr)?
                    },
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

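    // The wrappers below check the expected element counts for a given
    // operation shape and then launch the named kernel with a preferred
    // launch configuration.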
    fn check_and_launch_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                c_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

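    // Launches a matrix multiplication kernel for c = op(a) * op(b), where
    // the result `c` is n x m, `a` has n * l elements and `b` has l * m
    // elements.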
    fn check_and_launch_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch3(kernel_name, a, b, c, |a2, b2, c2| {
            if a2.len != n * l {
                return Err(Error::BackendArrayElemCount(a2.len, n * l));
            }
            if b2.len != l * m {
                return Err(Error::BackendArrayElemCount(b2.len, l * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param, c_param| {
            let config = preferred_launch_config(n, m, true, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                c_param,
                n.as_kernel_param(),
                m.as_kernel_param(),
                l.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, c, |a2, c2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, c_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b.as_kernel_param(),
                c_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param(),
                ((config.block_dim.1) as usize).as_kernel_param(),
                ((config.block_dim.0) as usize).as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != n {
                return Err(Error::BackendArrayElemCount(a2.len, n));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

    fn check_and_launch_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        let is_mma = self.has_mma;
        self.check_and_launch2(kernel_name, a, b, |a2, b2| {
            if a2.len != m {
                return Err(Error::BackendArrayElemCount(a2.len, m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |_, kernel, a_param, b_param| {
            let config = preferred_launch_config(n, m, false, is_mma);
            let mut params = vec![
                a_param,
                b_param,
                n.as_kernel_param(),
                m.as_kernel_param()
            ];
            unsafe {
                match kernel.launch(config, &mut params) {
                    Ok(()) => Ok(()),
                    Err(err) => Err(Error::Cuda(err)),
                }
            }
        })
    }

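    // Performs the multiplication with cuBLAS SGEMM. The matrices are stored
    // row-major, so the call computes C^T = op(B)^T * op(A)^T in cuBLAS's
    // column-major convention by swapping the operand order and choosing the
    // leading dimensions accordingly.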
    fn check_and_launch_cublas_for_mul(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize, is_trans_a: bool, is_trans_b: bool) -> Result<()>
    {
        self.check_and_launch_cublas3(a, b, c, |a2, b2, c2| {
            if a2.len != n * l {
                return Err(Error::BackendArrayElemCount(a2.len, n * l));
            }
            if b2.len != l * m {
                return Err(Error::BackendArrayElemCount(b2.len, l * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |inner, a_device_ptr, b_device_ptr, c_device_ptr| {
            unsafe {
                match &inner.cublas {
                    Some(cublas) => {
                        let (transa, lda) = if is_trans_a {
                            (cublasOperation_t::CUBLAS_OP_T, n as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, l as c_int)
                        };
                        let (transb, ldb) = if is_trans_b {
                            (cublasOperation_t::CUBLAS_OP_T, l as c_int)
                        } else {
                            (cublasOperation_t::CUBLAS_OP_N, m as c_int)
                        };
                        let alpha = 1.0f32;
                        let beta = 0.0f32;
                        let res = sgemm(*cublas.handle(),
                            transb, transa,
                            m as c_int, n as c_int, l as c_int,
                            (&alpha) as *const _,
                            b_device_ptr as *const _, ldb,
                            a_device_ptr as *const _, lda,
                            (&beta) as *const _,
                            c_device_ptr as *mut _, m as c_int);
                        match res {
                            Ok(()) => Ok(()),
                            Err(err) => Err(Error::Cublas(err)),
                        }
                    },
                    None => Err(Error::NoCublas),
                }
            }
        })
    }
}

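// `Backend` implementation: host/device transfers plus the element-wise,
// matrix multiplication, scalar, activation and repeat operations, each
// dispatched to the corresponding CUDA kernel (or to cuBLAS for matrix
// multiplication when it is enabled).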
impl Backend for CudaBackend
{
    fn name(&self) -> &'static str
    {
        if self.has_cublas {
            "CUDA(cuBLAS)"
        } else if self.has_mma {
            "CUDA(mma)"
        } else {
            "CUDA"
        }
    }

    fn has_cublas(&self) -> bool
    { self.has_cublas }

    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.alloc_zeros(n) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: n, };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let slice: CudaSlice<f32> = match inner_g.device.htod_sync_copy(elems) {
            Ok(tmp_slice) => tmp_slice,
            Err(err) => return Err(Error::Cuda(err)),
        };
        let cuda_array = CudaBackendArray { slice: Arc::new(Mutex::new(slice)), len: elems.len(), };
        Ok(BackendArray::Cuda(cuda_array))
    }

    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.dtoh_sync_copy_into(&(*a_slice_g), elems) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::Cuda(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_slice_g = mutex_lock(&a2.slice)?;
                match inner_g.device.htod_sync_copy_into(elems, &mut (*a_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::Cuda(a2), BackendArray::Cuda(b2)) => {
                if Arc::ptr_eq(&a2.slice, &b2.slice) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_slice_g = mutex_lock(&a2.slice)?;
                let mut b_slice_g = mutex_lock(&b2.slice)?;
                match inner_g.device.dtod_copy(&(*a_slice_g), &mut (*b_slice_g)) {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
                match inner_g.device.synchronize() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::Cuda(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("transpose_a", a, b, n, m) }

    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("add_at_bt", a, b, c, n, m) }

    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("sub_at_bt", a, b, c, n, m) }

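    // Matrix multiplication uses cuBLAS SGEMM when the backend was created
    // with cuBLAS, and the custom multiplication kernels otherwise.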
    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, false)
        } else {
            self.check_and_launch_for_mul("mul_a_b", a, b, c, n, m, l)
        }
    }

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, false)
        } else {
            self.check_and_launch_for_mul("mul_at_b", a, b, c, n, m, l)
        }
    }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, false, true)
        } else {
            self.check_and_launch_for_mul("mul_a_bt", a, b, c, n, m, l)
        }
    }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        if self.has_cublas {
            self.check_and_launch_cublas_for_mul(a, b, c, n, m, l, true, true)
        } else {
            self.check_and_launch_for_mul("mul_at_bt", a, b, c, n, m, l)
        }
    }

    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }

    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun("tanh_at", a, b, n, m) }

    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_launch_for_repeat_row("repeat_row_a", a, b, n, m) }
}

#[cfg(test)]
mod tests;