tensorgraph_math/
blas.rs

1use tensorgraph_sys::{device::Device, ptr::DPtr};
2
3mod cpu;
4
5#[cfg(feature = "cublas")]
6pub mod cublas;
7
/// The transposition applied to a matrix operand before a BLAS routine uses it.
///
/// The `#[repr(u8)]` discriminants are the ASCII characters the reference BLAS
/// uses for its `TRANS` arguments, so `op as u8` yields the expected character.
///
/// Conjugate transposition (`b'C'`) is not currently provided.
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum MatrixOp {
    /// Use the matrix as stored: `op(A) = A`.
    NoTrans = b'N',
    /// Use the transpose of the stored matrix: `op(A) = Aᵀ`.
    Trans = b'T',
}
16
17/// A context needed for running BLAS operations
18pub trait BLASContext: Clone {
19    type Device: Device;
20}
21
22/// The default blas context for a device
23pub trait DefaultBLASContext: Device {
24    type Context: BLASContext<Device = Self> + Default;
25}
26
27/// BLAS Level 1 operations (Vector only)
28#[allow(clippy::too_many_arguments)]
29pub trait BLAS1<C: BLASContext>: Sized + Copy {
30    /// Computes
31    /// > X = alpha * X
32    ///
33    /// # Safety
34    /// This is often a call across an FFI barrier, so the links or devices need to be
35    /// running and may perform UB unchecked by rust
36    unsafe fn scal(
37        ctx: C,
38        n: i32,
39        alpha: Self,
40        x: DPtr<Self, C::Device>,
41        incx: i32,
42    );
43
44    /// Computes
45    /// > Y = alpha * X + Y
46    ///
47    /// # Safety
48    /// This is often a call across an FFI barrier, so the links or devices need to be
49    /// running and may perform UB unchecked by rust
50    unsafe fn axpy(
51        ctx: C,
52        n: i32,
53        alpha: Self,
54        x: DPtr<Self, C::Device>,
55        incx: i32,
56        y: DPtr<Self, C::Device>,
57        incy: i32,
58    );
59
60    /// Computes the vector dot product
61    ///
62    /// # Safety
63    /// This is often a call across an FFI barrier, so the links or devices need to be
64    /// running and may perform UB unchecked by rust
65    unsafe fn dot(
66        ctx: C,
67        n: i32,
68        x: DPtr<Self, C::Device>,
69        incx: i32,
70        y: DPtr<Self, C::Device>,
71        incy: i32,
72    ) -> Self;
73}
74
75/// BLAS Level 2 operations (Matrix-Vector)
76#[allow(clippy::too_many_arguments)]
77pub trait BLAS2<C: BLASContext>: Sized + Copy {
78    /// Compute the **Ge**neralised **M**atrix-**V**ector multiplication:
79    /// > y = alpha * Ax + beta * y
80    ///
81    /// # Safety
82    /// This is often a call across an FFI barrier, so the links or devices need to be
83    /// running and may perform UB unchecked by rust
84    unsafe fn gemv(
85        ctx: C,
86        trans: MatrixOp,
87        m: i32,
88        n: i32,
89        alpha: Self,
90        a: DPtr<Self, C::Device>,
91        lda: i32,
92        x: DPtr<Self, C::Device>,
93        incx: i32,
94        beta: Self,
95        y: DPtr<Self, C::Device>,
96        incy: i32,
97    );
98}
99
100/// BLAS Level 3 operations (Matrix-Matrix)
101#[allow(clippy::too_many_arguments)]
102pub trait BLAS3<C: BLASContext>: Sized + Copy {
103    /// Compute the **Ge**neralised **M**atrix-**M**atrix multiplication:
104    /// > C = alpha * AB + beta * C
105    ///
106    /// # Safety
107    /// This is often a call across an FFI barrier, so the links or devices need to be
108    /// running and may perform UB unchecked by rust
109    unsafe fn gemm(
110        ctx: C,
111        transa: MatrixOp,
112        transb: MatrixOp,
113        m: i32,
114        n: i32,
115        k: i32,
116        alpha: Self,
117        a: DPtr<Self, C::Device>,
118        lda: i32,
119        b: DPtr<Self, C::Device>,
120        ldb: i32,
121        beta: Self,
122        c: DPtr<Self, C::Device>,
123        ldc: i32,
124    );
125}
126
127/// A complete BLAS library, levels 1, 2 and 3
128pub trait BLAS<C: BLASContext>: BLAS1<C> + BLAS2<C> + BLAS3<C> {}
129impl<F, C: BLASContext> BLAS<C> for F where F: BLAS1<C> + BLAS2<C> + BLAS3<C> {}