1use std::mem::MaybeUninit;
2
3use num_traits::{One, Zero};
4use tensorgraph_sys::{
5 device::{DefaultDeviceAllocator, Device, DeviceAllocator},
6 DefaultVec, Vec, View,
7};
8
9use crate::{
10 blas::{BLASContext, DefaultBLASContext, MatrixOp, BLAS, BLAS2, BLAS3},
11 storage::Storage,
12};
13
14use super::{Slice, Tensor, UninitVector, Vector, VectorView, VectorViewMut, ViewOf};
15
16pub type Matrix<S> = Tensor<S, [usize; 2]>;
18
19pub type MatrixView<'a, T, D> = Matrix<&'a Slice<T, D>>;
21
22pub type MatrixViewMut<'a, T, D> = Matrix<&'a mut Slice<T, D>>;
24
25pub type UninitMatrix<'a, T, D> = MatrixViewMut<'a, MaybeUninit<T>, D>;
27
28impl<S: Storage> Matrix<S> {
29 pub fn dot(&self, rhs: Vector<&ViewOf<S>>) -> Vector<DefaultVec<S::T, S::Device>>
31 where
32 S::Device: DefaultDeviceAllocator + DefaultBLASContext,
33 S::T: Zero + One + BLAS<<S::Device as DefaultBLASContext>::Context>,
34 {
35 self.dot_using(rhs, Default::default())
36 }
37
38 pub fn dot_using<C: BLASContext<Device = S::Device>>(
40 &self,
41 rhs: Vector<&ViewOf<S>>,
42 ctx: C,
43 ) -> Vector<DefaultVec<S::T, S::Device>>
44 where
45 S::Device: DefaultDeviceAllocator,
46 S::T: Zero + One + BLAS<C>,
47 {
48 self.dot_into(rhs, ctx, Default::default())
49 }
50
51 pub fn dot_into<C: BLASContext<Device = S::Device>, A: DeviceAllocator<Device = S::Device>>(
53 &self,
54 rhs: Vector<&ViewOf<S>>,
55 ctx: C,
56 alloc: A,
57 ) -> Vector<Vec<S::T, A>>
58 where
59 S::T: Zero + One + BLAS<C>,
60 {
61 let rows = self.shape[0];
62 let mut v = Vec::with_capacity_in(rows, alloc);
63 unsafe {
64 let uninit = Vector::from_shape([rows], &mut v.space_capacity_mut()[..rows]);
65
66 gemv_uninit_ctx(ctx, S::T::one(), self.view(), rhs, uninit);
67
68 v.set_len(rows);
69 }
70 Vector::from_shape([rows], v)
71 }
72}
73
74impl<S: Storage> Matrix<S> {
75 pub fn matmul(&self, rhs: Matrix<&ViewOf<S>>) -> Matrix<DefaultVec<S::T, S::Device>>
77 where
78 S::Device: DefaultDeviceAllocator + DefaultBLASContext,
79 S::T: Zero + One + BLAS<<S::Device as DefaultBLASContext>::Context>,
80 {
81 self.matmul_using(rhs, Default::default())
82 }
83
84 pub fn matmul_using<C: BLASContext<Device = S::Device>>(
86 &self,
87 rhs: Matrix<&ViewOf<S>>,
88 ctx: C,
89 ) -> Matrix<DefaultVec<S::T, S::Device>>
90 where
91 S::Device: DefaultDeviceAllocator,
92 S::T: Zero + One + BLAS<C>,
93 {
94 self.matmul_into(rhs, ctx, Default::default())
95 }
96
97 pub fn matmul_into<C: BLASContext<Device = S::Device>, A: DeviceAllocator<Device = S::Device>>(
99 &self,
100 rhs: Matrix<&ViewOf<S>>,
101 ctx: C,
102 alloc: A,
103 ) -> Matrix<Vec<S::T, A>>
104 where
105 S::T: Zero + One + BLAS<C>,
106 {
107 let rows = self.shape[0];
108 let cols = rhs.shape[1];
109 let mut v = Vec::with_capacity_in(rows * cols, alloc);
110 unsafe {
111 let uninit =
112 Matrix::from_shape([rows, cols], &mut v.space_capacity_mut()[..rows * cols]);
113
114 gemm_uninit_ctx(ctx, S::T::one(), self.view(), rhs, uninit);
115
116 v.set_len(rows * cols);
117 }
118 Matrix::from_shape([rows, cols], v)
119 }
120}
121
122pub fn gemv_uninit<F: BLAS2<D::Context> + Zero, D: DefaultBLASContext>(
133 alpha: F,
134 a: MatrixView<F, D>,
135 x: VectorView<F, D>,
136 y: UninitVector<F, D>,
137) {
138 gemv_uninit_ctx(D::Context::default(), alpha, a, x, y);
139}
140
141pub fn gemv_uninit_ctx<F: BLAS2<C> + Zero, C: BLASContext<Device = D>, D: Device>(
150 ctx: C,
151 alpha: F,
152 a: MatrixView<F, D>,
153 x: VectorView<F, D>,
154 y: UninitVector<F, D>,
155) {
156 unsafe { gemv_ctx(ctx, alpha, a, x, F::zero(), y.assume_init()) }
159}
160
161pub fn gemv<F: BLAS2<D::Context> + Zero, D: DefaultBLASContext>(
172 alpha: F,
173 a: MatrixView<F, D>,
174 x: VectorView<F, D>,
175 beta: F,
176 y: VectorViewMut<F, D>,
177) {
178 gemv_ctx(D::Context::default(), alpha, a, x, beta, y);
179}
180
181#[allow(
190 clippy::cast_possible_wrap,
191 clippy::cast_possible_truncation,
192 clippy::needless_pass_by_value
193)]
194pub fn gemv_ctx<F: BLAS2<C> + Zero, C: BLASContext<Device = D>, D: Device>(
195 ctx: C,
196 alpha: F,
197 a: MatrixView<F, D>,
198 x: VectorView<F, D>,
199 beta: F,
200 y: VectorViewMut<F, D>,
201) {
202 let [rowsa, colsa] = a.shape;
203 let [rowsx] = x.shape;
204 let [rowsy] = y.shape;
205 assert_eq!(rowsa, rowsy);
206 assert_eq!(colsa, rowsx);
207
208 let m = rowsa as i32;
209 let n = colsa as i32;
210
211 let (trans, lda) = lead(a.strides);
212 let incx = x.strides[0] as i32;
213 let incy = y.strides[0] as i32;
214
215 unsafe {
216 F::gemv(
217 ctx,
218 trans,
219 m,
220 n,
221 alpha,
222 a.data.as_ref().as_ptr(),
223 lda,
224 x.data.as_ref().as_ptr(),
225 incx,
226 beta,
227 y.data.as_ptr(),
228 incy,
229 );
230 }
231}
232
233pub fn gemm_uninit<F: BLAS3<D::Context> + Zero, D: DefaultBLASContext>(
244 alpha: F,
245 a: MatrixView<F, D>,
246 b: MatrixView<F, D>,
247 c: UninitMatrix<F, D>,
248) {
249 gemm_uninit_ctx(D::Context::default(), alpha, a, b, c);
250}
251
252pub fn gemm_uninit_ctx<F: BLAS3<C> + Zero, C: BLASContext<Device = D>, D: Device>(
261 ctx: C,
262 alpha: F,
263 a: MatrixView<F, D>,
264 b: MatrixView<F, D>,
265 c: UninitMatrix<F, D>,
266) {
267 unsafe { gemm_ctx(ctx, alpha, a, b, F::zero(), c.assume_init()) }
270}
271
272pub fn gemm<F: BLAS3<D::Context> + Zero, D: DefaultBLASContext>(
283 alpha: F,
284 a: MatrixView<F, D>,
285 b: MatrixView<F, D>,
286 beta: F,
287 c: MatrixViewMut<F, D>,
288) {
289 gemm_ctx(D::Context::default(), alpha, a, b, beta, c);
290}
291
292#[allow(
301 clippy::cast_possible_wrap,
302 clippy::cast_possible_truncation,
303 clippy::needless_pass_by_value
304)]
305pub fn gemm_ctx<F: BLAS3<C> + Zero, C: BLASContext<Device = D>, D: Device>(
306 ctx: C,
307 alpha: F,
308 a: MatrixView<F, D>,
309 b: MatrixView<F, D>,
310 beta: F,
311 c: MatrixViewMut<F, D>,
312) {
313 let [rowsa, colsa] = a.shape;
314 let [rowsb, colsb] = b.shape;
315 let [rowsc, colsc] = c.shape;
316 assert_eq!(rowsa, rowsc);
317 assert_eq!(colsb, colsc);
318 assert_eq!(colsa, rowsb);
319
320 let m = rowsa as i32;
321 let k = rowsb as i32;
322 let n = colsb as i32;
323
324 let (transa, lda) = lead(a.strides);
325 let (transb, ldb) = lead(b.strides);
326
327 assert_eq!(c.strides[0], 1);
329 let ldc = c.strides[1] as i32;
330
331 unsafe {
332 F::gemm(
333 ctx,
334 transa,
335 transb,
336 m,
337 n,
338 k,
339 alpha,
340 a.data.as_ref().as_ptr(),
341 lda,
342 b.data.as_ref().as_ptr(),
343 ldb,
344 beta,
345 c.data.as_ptr(),
346 ldc,
347 );
348 }
349}
350
351#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
352fn lead(s: [usize; 2]) -> (MatrixOp, i32) {
353 if s[0] == 1 {
354 (MatrixOp::NoTrans, s[1] as i32)
355 } else if s[1] == 1 {
356 (MatrixOp::Trans, s[0] as i32)
357 } else {
358 panic!("one of the strides must be 1 (contiguous)")
359 }
360}