tensorgraph_math/tensor/
matrix.rs

1use std::mem::MaybeUninit;
2
3use num_traits::{One, Zero};
4use tensorgraph_sys::{
5    device::{DefaultDeviceAllocator, Device, DeviceAllocator},
6    DefaultVec, Vec, View,
7};
8
9use crate::{
10    blas::{BLASContext, DefaultBLASContext, MatrixOp, BLAS, BLAS2, BLAS3},
11    storage::Storage,
12};
13
14use super::{Slice, Tensor, UninitVector, Vector, VectorView, VectorViewMut, ViewOf};
15
16/// A 2-dimensional tensor
17pub type Matrix<S> = Tensor<S, [usize; 2]>;
18
19/// A 'view' of a matrix, Like `&[T]` is to `Vec<T>`
20pub type MatrixView<'a, T, D> = Matrix<&'a Slice<T, D>>;
21
22/// A 'mut view' of a matrix, Like `&mut [T]` is to `Vec<T>`
23pub type MatrixViewMut<'a, T, D> = Matrix<&'a mut Slice<T, D>>;
24
25/// An uninit matrix. Contents are mutable and specified as [`MaybeUninit`].
26pub type UninitMatrix<'a, T, D> = MatrixViewMut<'a, MaybeUninit<T>, D>;
27
28impl<S: Storage> Matrix<S> {
29    /// Matrix-vector multiplication
30    pub fn dot(&self, rhs: Vector<&ViewOf<S>>) -> Vector<DefaultVec<S::T, S::Device>>
31    where
32        S::Device: DefaultDeviceAllocator + DefaultBLASContext,
33        S::T: Zero + One + BLAS<<S::Device as DefaultBLASContext>::Context>,
34    {
35        self.dot_using(rhs, Default::default())
36    }
37
38    /// Matrix-vector multiplication, using the specified [`BLASContext`]
39    pub fn dot_using<C: BLASContext<Device = S::Device>>(
40        &self,
41        rhs: Vector<&ViewOf<S>>,
42        ctx: C,
43    ) -> Vector<DefaultVec<S::T, S::Device>>
44    where
45        S::Device: DefaultDeviceAllocator,
46        S::T: Zero + One + BLAS<C>,
47    {
48        self.dot_into(rhs, ctx, Default::default())
49    }
50
51    /// Matrix-vector multiplication, using the provided [`DeviceAllocator`], using the specified [`BLASContext`]
52    pub fn dot_into<C: BLASContext<Device = S::Device>, A: DeviceAllocator<Device = S::Device>>(
53        &self,
54        rhs: Vector<&ViewOf<S>>,
55        ctx: C,
56        alloc: A,
57    ) -> Vector<Vec<S::T, A>>
58    where
59        S::T: Zero + One + BLAS<C>,
60    {
61        let rows = self.shape[0];
62        let mut v = Vec::with_capacity_in(rows, alloc);
63        unsafe {
64            let uninit = Vector::from_shape([rows], &mut v.space_capacity_mut()[..rows]);
65
66            gemv_uninit_ctx(ctx, S::T::one(), self.view(), rhs, uninit);
67
68            v.set_len(rows);
69        }
70        Vector::from_shape([rows], v)
71    }
72}
73
74impl<S: Storage> Matrix<S> {
75    /// Multiply two matricies together.
76    pub fn matmul(&self, rhs: Matrix<&ViewOf<S>>) -> Matrix<DefaultVec<S::T, S::Device>>
77    where
78        S::Device: DefaultDeviceAllocator + DefaultBLASContext,
79        S::T: Zero + One + BLAS<<S::Device as DefaultBLASContext>::Context>,
80    {
81        self.matmul_using(rhs, Default::default())
82    }
83
84    /// Multiply two matricies together, using the specified [`BLASContext`]
85    pub fn matmul_using<C: BLASContext<Device = S::Device>>(
86        &self,
87        rhs: Matrix<&ViewOf<S>>,
88        ctx: C,
89    ) -> Matrix<DefaultVec<S::T, S::Device>>
90    where
91        S::Device: DefaultDeviceAllocator,
92        S::T: Zero + One + BLAS<C>,
93    {
94        self.matmul_into(rhs, ctx, Default::default())
95    }
96
97    /// Multiply two matricies together, using the provided [`DeviceAllocator`], using the specified [`BLASContext`]
98    pub fn matmul_into<C: BLASContext<Device = S::Device>, A: DeviceAllocator<Device = S::Device>>(
99        &self,
100        rhs: Matrix<&ViewOf<S>>,
101        ctx: C,
102        alloc: A,
103    ) -> Matrix<Vec<S::T, A>>
104    where
105        S::T: Zero + One + BLAS<C>,
106    {
107        let rows = self.shape[0];
108        let cols = rhs.shape[1];
109        let mut v = Vec::with_capacity_in(rows * cols, alloc);
110        unsafe {
111            let uninit =
112                Matrix::from_shape([rows, cols], &mut v.space_capacity_mut()[..rows * cols]);
113
114            gemm_uninit_ctx(ctx, S::T::one(), self.view(), rhs, uninit);
115
116            v.set_len(rows * cols);
117        }
118        Matrix::from_shape([rows, cols], v)
119    }
120}
121
122/// Performs the basic matrix-vector multiplication operation.
123/// > y = alpha * Ax.
124///
125/// Uses the default [`BLASContext`] for the device.
126///
127/// # Panics
128/// If the shapes of the matricies do not match the following pattern:
129/// * A = (M, N)
130/// * X = (N)
131/// * Y = (M)
132pub fn gemv_uninit<F: BLAS2<D::Context> + Zero, D: DefaultBLASContext>(
133    alpha: F,
134    a: MatrixView<F, D>,
135    x: VectorView<F, D>,
136    y: UninitVector<F, D>,
137) {
138    gemv_uninit_ctx(D::Context::default(), alpha, a, x, y);
139}
140
141/// Performs the basic matrix-vector multiplication operation.
142/// > y = alpha * Ax.
143///
144/// # Panics
145/// If the shapes of the matricies do not match the following pattern:
146/// * A = (M, N)
147/// * X = (N)
148/// * Y = (M)
149pub fn gemv_uninit_ctx<F: BLAS2<C> + Zero, C: BLASContext<Device = D>, D: Device>(
150    ctx: C,
151    alpha: F,
152    a: MatrixView<F, D>,
153    x: VectorView<F, D>,
154    y: UninitVector<F, D>,
155) {
156    // Safety:
157    // Specifying beta == 0.0 should allow c to be safely read while uninitialised
158    unsafe { gemv_ctx(ctx, alpha, a, x, F::zero(), y.assume_init()) }
159}
160
161/// Performs the basic matrix-vector multiplication operation.
162/// > y = alpha * Ax + beta * y.
163///
164/// Uses the default [`BLASContext`] for the device.
165///
166/// # Panics
167/// If the shapes of the matricies do not match the following pattern:
168/// * A = (M, N)
169/// * X = (N)
170/// * Y = (M)
171pub fn gemv<F: BLAS2<D::Context> + Zero, D: DefaultBLASContext>(
172    alpha: F,
173    a: MatrixView<F, D>,
174    x: VectorView<F, D>,
175    beta: F,
176    y: VectorViewMut<F, D>,
177) {
178    gemv_ctx(D::Context::default(), alpha, a, x, beta, y);
179}
180
181/// Performs the basic matrix-vector multiplication operation.
182/// > y = alpha * Ax + beta * y.
183///
184/// # Panics
185/// If the shapes of the matricies do not match the following pattern:
186/// * A = (M, N)
187/// * X = (N)
188/// * Y = (M)
189#[allow(
190    clippy::cast_possible_wrap,
191    clippy::cast_possible_truncation,
192    clippy::needless_pass_by_value
193)]
194pub fn gemv_ctx<F: BLAS2<C> + Zero, C: BLASContext<Device = D>, D: Device>(
195    ctx: C,
196    alpha: F,
197    a: MatrixView<F, D>,
198    x: VectorView<F, D>,
199    beta: F,
200    y: VectorViewMut<F, D>,
201) {
202    let [rowsa, colsa] = a.shape;
203    let [rowsx] = x.shape;
204    let [rowsy] = y.shape;
205    assert_eq!(rowsa, rowsy);
206    assert_eq!(colsa, rowsx);
207
208    let m = rowsa as i32;
209    let n = colsa as i32;
210
211    let (trans, lda) = lead(a.strides);
212    let incx = x.strides[0] as i32;
213    let incy = y.strides[0] as i32;
214
215    unsafe {
216        F::gemv(
217            ctx,
218            trans,
219            m,
220            n,
221            alpha,
222            a.data.as_ref().as_ptr(),
223            lda,
224            x.data.as_ref().as_ptr(),
225            incx,
226            beta,
227            y.data.as_ptr(),
228            incy,
229        );
230    }
231}
232
233/// Performs the basic matmul operation.
234/// > C = alpha * AB.
235///
236/// Uses the default [`BLASContext`] for the device.
237///
238/// # Panics
239/// If the shapes of the matricies do not match the following pattern:
240/// * A = (M, K)
241/// * B = (K, N)
242/// * C = (M, N)
243pub fn gemm_uninit<F: BLAS3<D::Context> + Zero, D: DefaultBLASContext>(
244    alpha: F,
245    a: MatrixView<F, D>,
246    b: MatrixView<F, D>,
247    c: UninitMatrix<F, D>,
248) {
249    gemm_uninit_ctx(D::Context::default(), alpha, a, b, c);
250}
251
252/// Performs the basic matmul operation.
253/// > C = alpha * AB.
254///
255/// # Panics
256/// If the shapes of the matricies do not match the following pattern:
257/// * A = (M, K)
258/// * B = (K, N)
259/// * C = (M, N)
260pub fn gemm_uninit_ctx<F: BLAS3<C> + Zero, C: BLASContext<Device = D>, D: Device>(
261    ctx: C,
262    alpha: F,
263    a: MatrixView<F, D>,
264    b: MatrixView<F, D>,
265    c: UninitMatrix<F, D>,
266) {
267    // Safety:
268    // Specifying beta == 0.0 should allow c to be safely read while uninitialised
269    unsafe { gemm_ctx(ctx, alpha, a, b, F::zero(), c.assume_init()) }
270}
271
272/// Performs the basic matmul operation.
273/// > C = alpha * AB + beta * C.
274///
275/// Uses the default [`BLASContext`] for the device.
276///
277/// # Panics
278/// If the shapes of the matricies do not match the following pattern:
279/// * A = (M, K)
280/// * B = (K, N)
281/// * C = (M, N)
282pub fn gemm<F: BLAS3<D::Context> + Zero, D: DefaultBLASContext>(
283    alpha: F,
284    a: MatrixView<F, D>,
285    b: MatrixView<F, D>,
286    beta: F,
287    c: MatrixViewMut<F, D>,
288) {
289    gemm_ctx(D::Context::default(), alpha, a, b, beta, c);
290}
291
292/// Performs the basic matmul operation.
293/// > C = alpha * AB + beta * C.
294///
295/// # Panics
296/// If the shapes of the matricies do not match the following pattern:
297/// * A = (M, K)
298/// * B = (K, N)
299/// * C = (M, N)
300#[allow(
301    clippy::cast_possible_wrap,
302    clippy::cast_possible_truncation,
303    clippy::needless_pass_by_value
304)]
305pub fn gemm_ctx<F: BLAS3<C> + Zero, C: BLASContext<Device = D>, D: Device>(
306    ctx: C,
307    alpha: F,
308    a: MatrixView<F, D>,
309    b: MatrixView<F, D>,
310    beta: F,
311    c: MatrixViewMut<F, D>,
312) {
313    let [rowsa, colsa] = a.shape;
314    let [rowsb, colsb] = b.shape;
315    let [rowsc, colsc] = c.shape;
316    assert_eq!(rowsa, rowsc);
317    assert_eq!(colsb, colsc);
318    assert_eq!(colsa, rowsb);
319
320    let m = rowsa as i32;
321    let k = rowsb as i32;
322    let n = colsb as i32;
323
324    let (transa, lda) = lead(a.strides);
325    let (transb, ldb) = lead(b.strides);
326
327    // C must not be transposed
328    assert_eq!(c.strides[0], 1);
329    let ldc = c.strides[1] as i32;
330
331    unsafe {
332        F::gemm(
333            ctx,
334            transa,
335            transb,
336            m,
337            n,
338            k,
339            alpha,
340            a.data.as_ref().as_ptr(),
341            lda,
342            b.data.as_ref().as_ptr(),
343            ldb,
344            beta,
345            c.data.as_ptr(),
346            ldc,
347        );
348    }
349}
350
351#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
352fn lead(s: [usize; 2]) -> (MatrixOp, i32) {
353    if s[0] == 1 {
354        (MatrixOp::NoTrans, s[1] as i32)
355    } else if s[1] == 1 {
356        (MatrixOp::Trans, s[0] as i32)
357    } else {
358        panic!("one of the strides must be 1 (contiguous)")
359    }
360}