tensorgraph_math/blas.rs
1use tensorgraph_sys::{device::Device, ptr::DPtr};
2
3mod cpu;
4
5#[cfg(feature = "cublas")]
6pub mod cublas;
7
/// Selects how a matrix operand is read by a BLAS routine.
///
/// The enum is `#[repr(u8)]` and each discriminant is an ASCII byte
/// (`b'N'`, `b'T'`), so a variant can be cast with `as u8` to obtain the
/// single-character flag. These letters mirror the transpose flags of the
/// reference BLAS interface.
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum MatrixOp {
    /// Use the matrix as stored (flag byte `b'N'`).
    NoTrans = b'N',
    /// Use the transpose of the matrix (flag byte `b'T'`).
    Trans = b'T',
    // A conjugate-transpose variant (`ConjTrans = b'C'`) is intentionally not
    // exposed at the moment.
}
16
17/// A context needed for running BLAS operations
18pub trait BLASContext: Clone {
19 type Device: Device;
20}
21
22/// The default blas context for a device
23pub trait DefaultBLASContext: Device {
24 type Context: BLASContext<Device = Self> + Default;
25}
26
27/// BLAS Level 1 operations (Vector only)
28#[allow(clippy::too_many_arguments)]
29pub trait BLAS1<C: BLASContext>: Sized + Copy {
30 /// Computes
31 /// > X = alpha * X
32 ///
33 /// # Safety
34 /// This is often a call across an FFI barrier, so the links or devices need to be
35 /// running and may perform UB unchecked by rust
36 unsafe fn scal(
37 ctx: C,
38 n: i32,
39 alpha: Self,
40 x: DPtr<Self, C::Device>,
41 incx: i32,
42 );
43
44 /// Computes
45 /// > Y = alpha * X + Y
46 ///
47 /// # Safety
48 /// This is often a call across an FFI barrier, so the links or devices need to be
49 /// running and may perform UB unchecked by rust
50 unsafe fn axpy(
51 ctx: C,
52 n: i32,
53 alpha: Self,
54 x: DPtr<Self, C::Device>,
55 incx: i32,
56 y: DPtr<Self, C::Device>,
57 incy: i32,
58 );
59
60 /// Computes the vector dot product
61 ///
62 /// # Safety
63 /// This is often a call across an FFI barrier, so the links or devices need to be
64 /// running and may perform UB unchecked by rust
65 unsafe fn dot(
66 ctx: C,
67 n: i32,
68 x: DPtr<Self, C::Device>,
69 incx: i32,
70 y: DPtr<Self, C::Device>,
71 incy: i32,
72 ) -> Self;
73}
74
75/// BLAS Level 2 operations (Matrix-Vector)
76#[allow(clippy::too_many_arguments)]
77pub trait BLAS2<C: BLASContext>: Sized + Copy {
78 /// Compute the **Ge**neralised **M**atrix-**V**ector multiplication:
79 /// > y = alpha * Ax + beta * y
80 ///
81 /// # Safety
82 /// This is often a call across an FFI barrier, so the links or devices need to be
83 /// running and may perform UB unchecked by rust
84 unsafe fn gemv(
85 ctx: C,
86 trans: MatrixOp,
87 m: i32,
88 n: i32,
89 alpha: Self,
90 a: DPtr<Self, C::Device>,
91 lda: i32,
92 x: DPtr<Self, C::Device>,
93 incx: i32,
94 beta: Self,
95 y: DPtr<Self, C::Device>,
96 incy: i32,
97 );
98}
99
100/// BLAS Level 3 operations (Matrix-Matrix)
101#[allow(clippy::too_many_arguments)]
102pub trait BLAS3<C: BLASContext>: Sized + Copy {
103 /// Compute the **Ge**neralised **M**atrix-**M**atrix multiplication:
104 /// > C = alpha * AB + beta * C
105 ///
106 /// # Safety
107 /// This is often a call across an FFI barrier, so the links or devices need to be
108 /// running and may perform UB unchecked by rust
109 unsafe fn gemm(
110 ctx: C,
111 transa: MatrixOp,
112 transb: MatrixOp,
113 m: i32,
114 n: i32,
115 k: i32,
116 alpha: Self,
117 a: DPtr<Self, C::Device>,
118 lda: i32,
119 b: DPtr<Self, C::Device>,
120 ldb: i32,
121 beta: Self,
122 c: DPtr<Self, C::Device>,
123 ldc: i32,
124 );
125}
126
127/// A complete BLAS library, levels 1, 2 and 3
128pub trait BLAS<C: BLASContext>: BLAS1<C> + BLAS2<C> + BLAS3<C> {}
129impl<F, C: BLASContext> BLAS<C> for F where F: BLAS1<C> + BLAS2<C> + BLAS3<C> {}